diff --git a/README.md b/README.md index 6b92a179..2a8ea22e 100644 --- a/README.md +++ b/README.md @@ -150,6 +150,53 @@ The test suite uses Python's standard library, the built-in `unittest` library, |/tests | Source for unit tests | |/utils | Source for utilities shared by the examples and unit tests | +### Customization +* When working with custom search commands such as Custom Streaming Commands or Custom Generating Commands, We may need to add new fields to the records based on certain conditions. +* Structural changes like this may not be preserved. +* Make sure to use ``add_field(record, fieldname, value)`` method from SearchCommand to add a new field and value to the record. +* ___Note:__ Usage of ``add_field`` method is completely optional, if you are not facing any issues with field retention._ + +Do +```python +class CustomStreamingCommand(StreamingCommand): + def stream(self, records): + for index, record in enumerate(records): + if index % 1 == 0: + self.add_field(record, "odd_record", "true") + yield record +``` + +Don't +```python +class CustomStreamingCommand(StreamingCommand): + def stream(self, records): + for index, record in enumerate(records): + if index % 1 == 0: + record["odd_record"] = "true" + yield record +``` +### Customization for Generating Custom Search Command +* Generating Custom Search Command is used to generate events using SDK code. +* Make sure to use ``gen_record()`` method from SearchCommand to add a new record and pass event data as a key=value pair separated by , (mentioned in below example). + +Do +```python +@Configuration() + class GeneratorTest(GeneratingCommand): + def generate(self): + yield self.gen_record(_time=time.time(), one=1) + yield self.gen_record(_time=time.time(), two=2) +``` + +Don't +```python +@Configuration() + class GeneratorTest(GeneratingCommand): + def generate(self): + yield {'_time': time.time(), 'one': 1} + yield {'_time': time.time(), 'two': 2} +``` + ### Changelog The [CHANGELOG](CHANGELOG.md) contains a description of changes for each version of the SDK. For the latest version, see the [CHANGELOG.md](https://github.com/splunk/splunk-sdk-python/blob/master/CHANGELOG.md) on GitHub. diff --git a/splunklib/searchcommands/generating_command.py b/splunklib/searchcommands/generating_command.py index e766effb..6a75d2c2 100644 --- a/splunklib/searchcommands/generating_command.py +++ b/splunklib/searchcommands/generating_command.py @@ -213,13 +213,20 @@ def _execute(self, ifile, process): def _execute_chunk_v2(self, process, chunk): count = 0 + records = [] for row in process: - self._record_writer.write_record(row) + records.append(row) count += 1 if count == self._record_writer._maxresultrows: - self._finished = False - return - self._finished = True + break + + for row in records: + self._record_writer.write_record(row) + + if count == self._record_writer._maxresultrows: + self._finished = False + else: + self._finished = True def process(self, argv=sys.argv, ifile=sys.stdin, ofile=sys.stdout, allow_empty_input=True): """ Process data. diff --git a/splunklib/searchcommands/internals.py b/splunklib/searchcommands/internals.py index 85f9e0fe..fa32f0b1 100644 --- a/splunklib/searchcommands/internals.py +++ b/splunklib/searchcommands/internals.py @@ -508,6 +508,7 @@ def __init__(self, ofile, maxresultrows=None): self._chunk_count = 0 self._pending_record_count = 0 self._committed_record_count = 0 + self.custom_fields = set() @property def is_flushed(self): @@ -572,6 +573,7 @@ def write_record(self, record): def write_records(self, records): self._ensure_validity() + records = list(records) write_record = self._write_record for record in records: write_record(record) @@ -593,6 +595,7 @@ def _write_record(self, record): if fieldnames is None: self._fieldnames = fieldnames = list(record.keys()) + self._fieldnames.extend([i for i in self.custom_fields if i not in self._fieldnames]) value_list = imap(lambda fn: (str(fn), str('__mv_') + str(fn)), fieldnames) self._writerow(list(chain.from_iterable(value_list))) diff --git a/splunklib/searchcommands/search_command.py b/splunklib/searchcommands/search_command.py index 270569ad..5a626cc5 100644 --- a/splunklib/searchcommands/search_command.py +++ b/splunklib/searchcommands/search_command.py @@ -173,6 +173,14 @@ def logging_level(self, value): raise ValueError('Unrecognized logging level: {}'.format(value)) self._logger.setLevel(level) + def add_field(self, current_record, field_name, field_value): + self._record_writer.custom_fields.add(field_name) + current_record[field_name] = field_value + + def gen_record(self, **record): + self._record_writer.custom_fields |= set(record.keys()) + return record + record = Option(doc=''' **Syntax: record= diff --git a/tests/searchcommands/test_generator_command.py b/tests/searchcommands/test_generator_command.py index 3b2281e8..63ae3ac8 100644 --- a/tests/searchcommands/test_generator_command.py +++ b/tests/searchcommands/test_generator_command.py @@ -40,7 +40,6 @@ def generate(self): assert expected.issubset(seen) assert finished_seen - def test_allow_empty_input_for_generating_command(): """ Passing allow_empty_input for generating command will cause an error @@ -59,3 +58,29 @@ def generate(self): except ValueError as error: assert str(error) == "allow_empty_input cannot be False for Generating Commands" +def test_all_fieldnames_present_for_generated_records(): + @Configuration() + class GeneratorTest(GeneratingCommand): + def generate(self): + yield self.gen_record(_time=time.time(), one=1) + yield self.gen_record(_time=time.time(), two=2) + yield self.gen_record(_time=time.time(), three=3) + yield self.gen_record(_time=time.time(), four=4) + yield self.gen_record(_time=time.time(), five=5) + + generator = GeneratorTest() + in_stream = io.BytesIO() + in_stream.write(chunky.build_getinfo_chunk()) + in_stream.write(chunky.build_chunk({'action': 'execute'})) + in_stream.seek(0) + out_stream = io.BytesIO() + generator._process_protocol_v2([], in_stream, out_stream) + out_stream.seek(0) + + ds = chunky.ChunkedDataStream(out_stream) + fieldnames_expected = {'_time', 'one', 'two', 'three', 'four', 'five'} + fieldnames_actual = set() + for chunk in ds: + for row in chunk.data: + fieldnames_actual |= set(row.keys()) + assert fieldnames_expected.issubset(fieldnames_actual) diff --git a/tests/searchcommands/test_internals_v2.py b/tests/searchcommands/test_internals_v2.py index bdef65c4..34e6b61c 100755 --- a/tests/searchcommands/test_internals_v2.py +++ b/tests/searchcommands/test_internals_v2.py @@ -233,6 +233,8 @@ def test_record_writer_with_random_data(self, save_recording=False): self.assertGreater(writer._buffer.tell(), 0) self.assertEqual(writer._total_record_count, 0) self.assertEqual(writer.committed_record_count, 0) + fieldnames.sort() + writer._fieldnames.sort() self.assertListEqual(writer._fieldnames, fieldnames) self.assertListEqual(writer._inspector['messages'], messages) diff --git a/tests/searchcommands/test_streaming_command.py b/tests/searchcommands/test_streaming_command.py index dcc00b53..ffe6a737 100644 --- a/tests/searchcommands/test_streaming_command.py +++ b/tests/searchcommands/test_streaming_command.py @@ -27,3 +27,71 @@ def stream(self, records): output = chunky.ChunkedDataStream(ofile) getinfo_response = output.read_chunk() assert getinfo_response.meta["type"] == "streaming" + + +def test_field_preservation_negative(): + @Configuration() + class TestStreamingCommand(StreamingCommand): + + def stream(self, records): + for index, record in enumerate(records): + if index % 2 != 0: + record["odd_field"] = True + else: + record["even_field"] = True + yield record + + cmd = TestStreamingCommand() + ifile = io.BytesIO() + ifile.write(chunky.build_getinfo_chunk()) + data = list() + for i in range(0, 10): + data.append({"in_index": str(i)}) + ifile.write(chunky.build_data_chunk(data, finished=True)) + ifile.seek(0) + ofile = io.BytesIO() + cmd._process_protocol_v2([], ifile, ofile) + ofile.seek(0) + output_iter = chunky.ChunkedDataStream(ofile).__iter__() + output_iter.next() + output_records = [i for i in output_iter.next().data] + + # Assert that count of records having "odd_field" is 0 + assert len(list(filter(lambda r: "odd_field" in r, output_records))) == 0 + + # Assert that count of records having "even_field" is 10 + assert len(list(filter(lambda r: "even_field" in r, output_records))) == 10 + + +def test_field_preservation_positive(): + @Configuration() + class TestStreamingCommand(StreamingCommand): + + def stream(self, records): + for index, record in enumerate(records): + if index % 2 != 0: + self.add_field(record, "odd_field", True) + else: + self.add_field(record, "even_field", True) + yield record + + cmd = TestStreamingCommand() + ifile = io.BytesIO() + ifile.write(chunky.build_getinfo_chunk()) + data = list() + for i in range(0, 10): + data.append({"in_index": str(i)}) + ifile.write(chunky.build_data_chunk(data, finished=True)) + ifile.seek(0) + ofile = io.BytesIO() + cmd._process_protocol_v2([], ifile, ofile) + ofile.seek(0) + output_iter = chunky.ChunkedDataStream(ofile).__iter__() + output_iter.next() + output_records = [i for i in output_iter.next().data] + + # Assert that count of records having "odd_field" is 10 + assert len(list(filter(lambda r: "odd_field" in r, output_records))) == 10 + + # Assert that count of records having "even_field" is 10 + assert len(list(filter(lambda r: "even_field" in r, output_records))) == 10