Skip to content

Commit 4d9ab27

Browse files
Replace TaggedFilesHierarchy with os.walk and implement configure_directory entrypoint (#695)
This PR adds a configure_directory entry point, as well as tests. It also removes TaggedFilesHierarchy and replaces it with os.walk. Finally, the Generator tests have been refactored. [ reviewed by @MattToast @mellis13 @juliaputko ] [ committed by @amandarichardsonn ]
1 parent 4faf95c commit 4d9ab27

File tree

9 files changed

+735
-621
lines changed

9 files changed

+735
-621
lines changed

smartsim/_core/commands/command_list.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@
3434
class CommandList(MutableSequence[Command]):
3535
"""Container for a Sequence of Command objects"""
3636

37-
def __init__(self, commands: t.Union[Command, t.List[Command]]):
37+
def __init__(self, commands: t.Optional[t.Union[Command, t.List[Command]]] = None):
3838
"""CommandList constructor"""
39-
if isinstance(commands, Command):
39+
if commands is None:
40+
commands = []
41+
elif isinstance(commands, Command):
4042
commands = [commands]
4143
self._commands: t.List[Command] = list(commands)
4244

smartsim/_core/entrypoints/file_operations.py

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def _abspath(input_path: str) -> pathlib.Path:
4949
"""Helper function to check that paths are absolute"""
5050
path = pathlib.Path(input_path)
5151
if not path.is_absolute():
52-
raise ValueError(f"path `{path}` must be absolute")
52+
raise ValueError(f"Path `{path}` must be absolute.")
5353
return path
5454

5555

@@ -62,6 +62,22 @@ def _make_substitution(
6262
)
6363

6464

65+
def _prepare_param_dict(param_dict: str) -> dict[str, t.Any]:
66+
"""Decode and deserialize a base64-encoded parameter dictionary.
67+
68+
This function takes a base64-encoded string representation of a dictionary,
69+
decodes it, and then deserializes it using pickle. It performs validation
70+
to ensure the resulting object is a non-empty dictionary.
71+
"""
72+
decoded_dict = base64.b64decode(param_dict)
73+
deserialized_dict = pickle.loads(decoded_dict)
74+
if not isinstance(deserialized_dict, dict):
75+
raise TypeError("param dict is not a valid dictionary")
76+
if not deserialized_dict:
77+
raise ValueError("param dictionary is empty")
78+
return deserialized_dict
79+
80+
6581
def _replace_tags_in(
6682
item: str,
6783
substitutions: t.Sequence[Callable[[str], str]],
@@ -70,6 +86,23 @@ def _replace_tags_in(
7086
return functools.reduce(lambda a, fn: fn(a), substitutions, item)
7187

7288

89+
def _process_file(
90+
substitutions: t.Sequence[Callable[[str], str]],
91+
source: pathlib.Path,
92+
destination: pathlib.Path,
93+
) -> None:
94+
"""
95+
Process a source file by replacing tags with specified substitutions and
96+
write the result to a destination file.
97+
"""
98+
# Set the lines to iterate over
99+
with open(source, "r+", encoding="utf-8") as file_stream:
100+
lines = [_replace_tags_in(line, substitutions) for line in file_stream]
101+
# write configured file to destination specified
102+
with open(destination, "w+", encoding="utf-8") as file_stream:
103+
file_stream.writelines(lines)
104+
105+
73106
def move(parsed_args: argparse.Namespace) -> None:
74107
"""Move a source file or directory to another location. If dest is an
75108
existing directory or a symlink to a directory, then the srouce will
@@ -155,9 +188,9 @@ def symlink(parsed_args: argparse.Namespace) -> None:
155188

156189
def configure(parsed_args: argparse.Namespace) -> None:
157190
"""Set, search and replace the tagged parameters for the
158-
configure operation within tagged files attached to an entity.
191+
configure_file operation within tagged files attached to an entity.
159192
160-
User-formatted files can be attached using the `configure` argument.
193+
User-formatted files can be attached using the `configure_file` argument.
161194
These files will be modified during ``Application`` generation to replace
162195
tagged sections in the user-formatted files with values from the `params`
163196
initializer argument used during ``Application`` creation:
@@ -166,39 +199,38 @@ def configure(parsed_args: argparse.Namespace) -> None:
166199
.. highlight:: bash
167200
.. code-block:: bash
168201
python -m smartsim._core.entrypoints.file_operations \
169-
configure /absolute/file/source/pat /absolute/file/dest/path \
202+
configure_file /absolute/file/source/path /absolute/file/dest/path \
170203
tag_deliminator param_dict
171204
172205
/absolute/file/source/path: The tagged files the search and replace operations
173206
to be performed upon
174207
/absolute/file/dest/path: The destination for configured files to be
175208
written to.
176-
tag_delimiter: tag for the configure operation to search for, defaults to
209+
tag_delimiter: tag for the configure_file operation to search for, defaults to
177210
semi-colon e.g. ";"
178211
param_dict: A dict of parameter names and values set for the file
179212
180213
"""
181214
tag_delimiter = parsed_args.tag_delimiter
182-
183-
decoded_dict = base64.b64decode(parsed_args.param_dict)
184-
param_dict = pickle.loads(decoded_dict)
185-
186-
if not param_dict:
187-
raise ValueError("param dictionary is empty")
188-
if not isinstance(param_dict, dict):
189-
raise TypeError("param dict is not a valid dictionary")
215+
param_dict = _prepare_param_dict(parsed_args.param_dict)
190216

191217
substitutions = tuple(
192218
_make_substitution(k, v, tag_delimiter) for k, v in param_dict.items()
193219
)
194-
195-
# Set the lines to iterate over
196-
with open(parsed_args.source, "r+", encoding="utf-8") as file_stream:
197-
lines = [_replace_tags_in(line, substitutions) for line in file_stream]
198-
199-
# write configured file to destination specified
200-
with open(parsed_args.dest, "w+", encoding="utf-8") as file_stream:
201-
file_stream.writelines(lines)
220+
if parsed_args.source.is_dir():
221+
for dirpath, _, filenames in os.walk(parsed_args.source):
222+
new_dir_dest = dirpath.replace(
223+
str(parsed_args.source), str(parsed_args.dest), 1
224+
)
225+
os.makedirs(new_dir_dest, exist_ok=True)
226+
for file_name in filenames:
227+
src_file = os.path.join(dirpath, file_name)
228+
dst_file = os.path.join(new_dir_dest, file_name)
229+
print(type(substitutions))
230+
_process_file(substitutions, src_file, dst_file)
231+
else:
232+
dst_file = parsed_args.dest / os.path.basename(parsed_args.source)
233+
_process_file(substitutions, parsed_args.source, dst_file)
202234

203235

204236
def get_parser() -> argparse.ArgumentParser:

0 commit comments

Comments
 (0)