diff --git a/docstring_to_markdown/rst.py b/docstring_to_markdown/rst.py
index 4ca813f..40b42f0 100644
--- a/docstring_to_markdown/rst.py
+++ b/docstring_to_markdown/rst.py
@@ -198,18 +198,26 @@ class State(IntEnum):
         PARSING_ROWS = auto()
         FINISHED = auto()

-    outer_border_pattern = r'^(\s*)=+( +=+)+$'
+    outer_border_pattern: str
+    column_top_prefix: str
+    column_top_border: str
+    column_end_offset: int

     _state: int
     _column_starts: List[int]
+    _columns_end: int
     _columns: List[str]
     _rows: List[List[str]]
     _max_sizes: List[int]
     _indent: str

+    def __init__(self):
+        self._reset_state()
+
     def _reset_state(self):
         self._state = TableParser.State.AWAITS
         self._column_starts = []
+        self._columns_end = -1
         self._columns = []
         self._rows = []
         self._max_sizes = []
@@ -222,11 +230,13 @@ def initiate_parsing(self, line: str, current_language: str) -> IBlockBeginning:
         self._reset_state()
         match = re.match(self.outer_border_pattern, line)
         assert match
-        self._indent = match.group(1) or ''
+        groups = match.groupdict()
+        self._indent = groups['indent'] or ''
         self._column_starts = []
-        previous = ' '
+        self._columns_end = match.end('column')
+        previous = self.column_top_prefix
         for i, char in enumerate(line):
-            if char == '=' and previous == ' ':
+            if char == self.column_top_border and previous == self.column_top_prefix:
                 self._column_starts.append(i)
             previous = char
         self._max_sizes = [0 for i in self._column_starts]
@@ -245,17 +255,24 @@ def consume(self, line: str) -> None:
                 # TODO: check integrity?
                 self._state += 1
         elif self._state == states.PARSING_ROWS:
-            match = re.match(self.outer_border_pattern, line)
-            if match:
-                self._state += 1
-            else:
-                self._rows.append(self._split(line))
+            self._consume_row(line)
+
+    def _consume_row(self, line: str):
+        match = re.match(self.outer_border_pattern, line)
+        if match:
+            self._state += 1
+        else:
+            self._rows.append(self._split(line))

     def _split(self, line: str) -> List[str]:
         assert self._column_starts
         fragments = []
         for i, start in enumerate(self._column_starts):
-            end = self._column_starts[i + 1] if i < len(self._column_starts) - 1 else None
+            end = (
+                self._column_starts[i + 1] + self.column_end_offset
+                if i < len(self._column_starts) - 1 else
+                self._columns_end
+            )
             fragment = line[start:end].strip()
             self._max_sizes[i] = max(self._max_sizes[i], len(fragment))
             fragments.append(fragment)
@@ -281,6 +298,48 @@ def finish_consumption(self, final: bool) -> str:
         return result


+class SimpleTableParser(TableParser):
+    outer_border_pattern = r'^(?P<indent>\s*)=+(?P<column> +=+)+$'
+    column_top_prefix = ' '
+    column_top_border = '='
+    column_end_offset = 0
+
+
+class GridTableParser(TableParser):
+    outer_border_pattern = r'^(?P<indent>\s*)(?P<column>\+-+)+\+$'
+    column_top_prefix = '+'
+    column_top_border = '-'
+    column_end_offset = -1
+
+    _expecting_row_content: bool
+
+    def _reset_state(self):
+        super()._reset_state()
+        self._expecting_row_content = True
+
+    def _is_correct_row(self, line: str) -> bool:
+        stripped = line.lstrip()
+        if self._expecting_row_content:
+            return stripped.startswith('|')
+        else:
+            return stripped.startswith('+-')
+
+    def can_consume(self, line: str) -> bool:
+        return (
+            bool(self._state != TableParser.State.FINISHED)
+            and
+            (self._state != TableParser.State.PARSING_ROWS or self._is_correct_row(line))
+        )
+
+    def _consume_row(self, line: str):
+        if self._is_correct_row(line):
+            if self._expecting_row_content:
+                self._rows.append(self._split(line))
+            self._expecting_row_content = not self._expecting_row_content
+        else:
+            self._state += 1
+
+
 class BlockParser(IParser):
     enclosure = '```'
     follower: Union['IParser', None] = None
@@ -445,7 +504,8 @@ def initiate_parsing(self, line: str, current_language: str) -> IBlockBeginning:
     MathBlockParser(),
     ExplicitCodeBlockParser(),
     DoubleColonBlockParser(),
-    TableParser()
+    SimpleTableParser(),
+    GridTableParser()
 ]

 RST_SECTIONS = {
diff --git a/tests/test_rst.py b/tests/test_rst.py
index 09ff5de..b6d171c 100644
--- a/tests/test_rst.py
+++ b/tests/test_rst.py
@@ -531,6 +531,43 @@ def func(): pass
     If True, then sub-classes will be passed-through, otherwise
 """

+GRID_TABLE_IN_SKLEARN = """
+Attributes
+----------
+cv_results_ : dict of numpy (masked) ndarrays
+    A dict with keys as column headers and values as columns, that can be
+    imported into a pandas ``DataFrame``.
+    For instance the below given table
+    +------------+-----------+------------+-----------------+---+---------+
+    |param_kernel|param_gamma|param_degree|split0_test_score|...|rank_t...|
+    +============+===========+============+=================+===+=========+
+    |  'poly'    |     --    |      2     |       0.80      |...|    2    |
+    +------------+-----------+------------+-----------------+---+---------+
+    |  'poly'    |     --    |      3     |       0.70      |...|    4    |
+    +------------+-----------+------------+-----------------+---+---------+
+    |  'rbf'     |     0.1   |     --     |       0.80      |...|    3    |
+    +------------+-----------+------------+-----------------+---+---------+
+    |  'rbf'     |     0.2   |     --     |       0.93      |...|    1    |
+    +------------+-----------+------------+-----------------+---+---------+
+    will be represented by a ``cv_results_`` dict
+"""
+
+GRID_TABLE_IN_SKLEARN_MARKDOWN = """
+#### Attributes
+
+- `cv_results_`: dict of numpy (masked) ndarrays
+    A dict with keys as column headers and values as columns, that can be
+    imported into a pandas ``DataFrame``.
+    For instance the below given table
+    | param_kernel | param_gamma | param_degree | split0_test_score | ... | rank_t... |
+    | ------------ | ----------- | ------------ | ----------------- | --- | --------- |
+    | 'poly'       | --          | 2            | 0.80              | ... | 2         |
+    | 'poly'       | --          | 3            | 0.70              | ... | 4         |
+    | 'rbf'        | 0.1         | --           | 0.80              | ... | 3         |
+    | 'rbf'        | 0.2         | --           | 0.93              | ... | 1         |
+    will be represented by a ``cv_results_`` dict
+"""
+
 INTEGRATION = """
 Return a fixed frequency DatetimeIndex.

@@ -661,6 +698,10 @@ def func(): pass
     'converts indented simple table': {
         'rst': SIMPLE_TABLE_IN_PARAMS,
         'md': SIMPLE_TABLE_IN_PARAMS_MARKDOWN
+    },
+    'converts indented grid table': {
+        'rst': GRID_TABLE_IN_SKLEARN,
+        'md': GRID_TABLE_IN_SKLEARN_MARKDOWN
     }
 }
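For reviewers skimming the refactor: the named groups introduced in `outer_border_pattern` (`indent` and `column`) are what let the shared `initiate_parsing`/`_split` logic serve both border styles. A minimal standalone sketch of that column-detection idea, using only local names (nothing below is code from the diff):

```python
import re

# Same shape as GridTableParser.outer_border_pattern in the diff.
outer_border_pattern = r'^(?P<indent>\s*)(?P<column>\+-+)+\+$'

line = "+------------+-----------+"
match = re.match(outer_border_pattern, line)
assert match

# A column starts wherever the border character follows the prefix character
# ('-' after '+' for grid tables, '=' after ' ' for simple tables).
column_starts = [
    i for i, char in enumerate(line)
    if char == '-' and i > 0 and line[i - 1] == '+'
]
print(column_starts)        # [1, 14]

# end('column') marks where the last column's cells stop; column_end_offset = -1
# then trims the '+'/'|' separator off the end of every other column slice.
print(match.end('column'))  # 25, i.e. the position of the trailing '+'
```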
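And an end-to-end sketch of what registering `GridTableParser` should enable, assuming the module's public `rst_to_markdown()` entry point (presumably the function the fixtures above are run through); the exact output spacing is illustrative, not asserted by this diff:

```python
from docstring_to_markdown.rst import rst_to_markdown

docstring = """\
For instance the grid table

+--------+-------+
| kernel | gamma |
+========+=======+
| 'rbf'  | 0.1   |
+--------+-------+
"""

print(rst_to_markdown(docstring))
# With GridTableParser active, the grid table should come out as a Markdown
# pipe table along the lines of:
#
# | kernel | gamma |
# | ------ | ----- |
# | 'rbf'  | 0.1   |
```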