Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 95 additions & 1 deletion cmdstanpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import sys
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from io import StringIO
from multiprocessing import cpu_count
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, Union
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union

import ujson as json
from tqdm.auto import tqdm
Expand All @@ -39,6 +40,7 @@
MaybeDictToFilePath,
SanitizedOrTmpFilePath,
cmdstan_path,
cmdstan_version,
cmdstan_version_before,
do_command,
get_logger,
Expand Down Expand Up @@ -297,6 +299,98 @@ def src_info(self) -> Dict[str, Any]:
get_logger().debug(e)
return result

def format(
self,
overwrite_file: bool = False,
canonicalize: Union[bool, str, Iterable[str]] = False,
max_line_length: int = 78,
*,
backup: bool = True,
) -> None:
"""
Run stanc's auto-formatter on the model code. Either saves directly
back to the file or prints for inspection


:param overwrite_file: If True, save the updated code to disk, rather
than printing it. By default False
:param canonicalize: Whether or not the compiler should 'canonicalize'
the Stan model, removing things like deprecated syntax. Default is
False. If True, all canonicalizations are run. If it is a list of
strings, those options are passed to stanc (new in Stan 2.29)
:param max_line_length: Set the wrapping point for the formatter. The
default value is 78, which wraps most lines by the 80th character.
:param backup: If True, create a stanfile.bak backup before
writing to the file. Only disable this if you're sure you have other
copies of the file or are using a version control system like Git.
"""
if self.stan_file is None or not os.path.isfile(self.stan_file):
raise ValueError("No Stan file found for this module")
try:
cmd = (
[os.path.join(cmdstan_path(), 'bin', 'stanc' + EXTENSION)]
# handle include-paths, allow-undefined etc
+ self._compiler_options.compose_stanc()
+ [self.stan_file]
)

if canonicalize:
if cmdstan_version_before(2, 29):
if isinstance(canonicalize, bool):
cmd.append('--print-canonical')
else:
raise ValueError(
"Invalid arguments passed for current CmdStan"
+ " version({})\n".format(
cmdstan_version() or "Unknown"
)
+ "--canonicalize requires 2.29 or higher"
)
else:
if isinstance(canonicalize, str):
cmd.append('--canonicalize=' + canonicalize)
elif isinstance(canonicalize, Iterable):
cmd.append('--canonicalize=' + ','.join(canonicalize))
else:
cmd.append('--print-canonical')

# before 2.29, having both --print-canonical
# and --auto-format printed twice
if not (cmdstan_version_before(2, 29) and canonicalize):
cmd.append('--auto-format')

if not cmdstan_version_before(2, 29):
cmd.append(f'--max-line-length={max_line_length}')
elif max_line_length != 78:
raise ValueError(
"Invalid arguments passed for current CmdStan version"
+ " ({})\n".format(cmdstan_version() or "Unknown")
+ "--max-line-length requires 2.29 or higher"
)

out = subprocess.run(
cmd, capture_output=True, text=True, check=True
)
if out.stderr:
get_logger().warning(out.stderr)
result = out.stdout
if overwrite_file:
if result:
if backup:
shutil.copyfile(
self.stan_file,
self.stan_file
+ '.bak-'
+ datetime.now().strftime("%Y%m%d%H%M%S"),
)
with (open(self.stan_file, 'w')) as file_handle:
file_handle.write(result)
else:
print(result)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Printing is not very pythonic. Should it just return the result?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am thinking of how a user would use this method and I have a hard time imagining when they'd actually want the string as a result. If anything it encourages the style of programming we don't want, where the model code ends up stored as a string in the python code.

The workflow I'm imagining for this function is that it will either be used so that someone can preview the output, in which case printing is the right thing to do, or they will be batch-processing stan files and formatting them, in which case they'd be saving them back to disk.

I can be persuaded that this is too narrow, however

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I think printing is ok in this case

except (ValueError, RuntimeError) as e:
raise RuntimeError("Stanc formatting failed") from e

@property
def stanc_options(self) -> Dict[str, Union[bool, int, str]]:
"""Options to stanc compilers."""
Expand Down
4 changes: 3 additions & 1 deletion test/data/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@
# and we re-ignore hpp and exe files
*.hpp
*.exe
!return_one.hpp
*.testbak
*.bak-*
!return_one.hpp
9 changes: 9 additions & 0 deletions test/data/format_me.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
generated quantities {
array[10,10,10,10,10] matrix[100,100] a_very_long_name;
int x = (((10)));
int y;
if (1)
y = 1;
else
y=2;
}
5 changes: 5 additions & 0 deletions test/data/format_me_deprecations.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# pound-sign comment
generated quantities {
int x;
x <- (((((3)))));
}
90 changes: 89 additions & 1 deletion test/test_model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
"""CmdStanModel tests"""

import contextlib
import io
import logging
import os
import shutil
import tempfile
import unittest
from glob import glob
from test import CustomTestCase
from unittest.mock import MagicMock, patch

import pytest
from testfixtures import LogCapture, StringComparison

from cmdstanpy.model import CmdStanModel
from cmdstanpy.utils import EXTENSION
from cmdstanpy.utils import EXTENSION, cmdstan_version_before

HERE = os.path.dirname(os.path.abspath(__file__))
DATAFILES_PATH = os.path.join(HERE, 'data')
Expand All @@ -34,6 +39,7 @@
BERN_BASENAME = 'bernoulli'


# pylint: disable=too-many-public-methods
class CmdStanModelTest(CustomTestCase):
def test_model_good(self):
# compile on instantiation, override model name
Expand Down Expand Up @@ -374,6 +380,88 @@ def test_model_includes_implicit(self):
model2 = CmdStanModel(stan_file=stan)
self.assertPathsEqual(model2.exe_file, exe)

@pytest.mark.skipif(
not cmdstan_version_before(2, 32),
reason="Deprecated syntax removed in Stan 2.32",
)
def test_model_format_deprecations(self):
stan = os.path.join(DATAFILES_PATH, 'format_me_deprecations.stan')

model = CmdStanModel(stan_file=stan, compile=False)

sys_stdout = io.StringIO()
with contextlib.redirect_stdout(sys_stdout):
model.format()

formatted = sys_stdout.getvalue()
self.assertIn("//", formatted)
self.assertNotIn("#", formatted)
self.assertEqual(formatted.count('('), 5)

sys_stdout = io.StringIO()
with contextlib.redirect_stdout(sys_stdout):
model.format(canonicalize=True)

formatted = sys_stdout.getvalue()
print(formatted)
self.assertNotIn("<-", formatted)
self.assertEqual(formatted.count('('), 0)

shutil.copy(stan, stan + '.testbak')
try:
model.format(overwrite_file=True, canonicalize=True)
self.assertEqual(len(glob(stan + '.bak-*')), 1)
finally:
shutil.copy(stan + '.testbak', stan)

@pytest.mark.skipif(
cmdstan_version_before(2, 29), reason='Options only available later'
)
def test_model_format_options(self):
stan = os.path.join(DATAFILES_PATH, 'format_me.stan')

model = CmdStanModel(stan_file=stan, compile=False)

sys_stdout = io.StringIO()
with contextlib.redirect_stdout(sys_stdout):
model.format(max_line_length=10)
formatted = sys_stdout.getvalue()
self.assertGreater(len(formatted.splitlines()), 11)

sys_stdout = io.StringIO()
with contextlib.redirect_stdout(sys_stdout):
model.format(canonicalize='braces')
formatted = sys_stdout.getvalue()
self.assertEqual(formatted.count('{'), 3)
self.assertEqual(formatted.count('('), 4)

sys_stdout = io.StringIO()
with contextlib.redirect_stdout(sys_stdout):
model.format(canonicalize=['parentheses'])
formatted = sys_stdout.getvalue()
self.assertEqual(formatted.count('{'), 1)
self.assertEqual(formatted.count('('), 1)

sys_stdout = io.StringIO()
with contextlib.redirect_stdout(sys_stdout):
model.format(canonicalize=True)
formatted = sys_stdout.getvalue()
self.assertEqual(formatted.count('{'), 3)
self.assertEqual(formatted.count('('), 1)

@patch('cmdstanpy.utils.cmdstan_version', MagicMock(return_value=(2, 27)))
def test_format_old_version(self):
self.assertTrue(cmdstan_version_before(2, 28))

stan = os.path.join(DATAFILES_PATH, 'format_me.stan')
model = CmdStanModel(stan_file=stan, compile=False)
with self.assertRaisesRegexNested(RuntimeError, r"--canonicalize"):
model.format(canonicalize='braces')
with self.assertRaisesRegexNested(RuntimeError, r"--max-line"):
model.format(max_line_length=88)

model.format(canonicalize=True)


if __name__ == '__main__':
unittest.main()