6
6
from autoPyTorch .pipeline .components .base_component import autoPyTorchComponent
7
7
8
8
9
- class HyperparameterSearchSpaceUpdate ():
9
+ class HyperparameterSearchSpaceUpdate :
10
+ """
11
+ Allows specifying update to the search space of a
12
+ particular hyperparameter.
13
+
14
+ Args:
15
+ node_name (str):
16
+ The name of the node in the pipeline
17
+ hyperparameter (str):
18
+ The name of the hyperparameter
19
+ value_range (Union[List, Tuple]):
20
+ In case of categorical hyperparameter, defines the new categorical choices.
21
+ In case of numerical hyperparameter, defines the new range
22
+ in the form of (LOWER, UPPER)
23
+ default_value (Union[int, float, str]):
24
+ New default value for the hyperparameter
25
+ log (bool) (default=False):
26
+ In case of numerical hyperparameters, whether to sample on a log scale
27
+ """
10
28
def __init__ (self , node_name : str , hyperparameter : str , value_range : Union [List , Tuple ],
11
29
default_value : Union [int , float , str ], log : bool = False ) -> None :
12
30
self .node_name = node_name
@@ -16,6 +34,15 @@ def __init__(self, node_name: str, hyperparameter: str, value_range: Union[List,
16
34
self .default_value = default_value
17
35
18
36
def apply (self , pipeline : List [Tuple [str , Union [autoPyTorchComponent , autoPyTorchChoice ]]]) -> None :
37
+ """
38
+ Applies the update to the appropriate hyperparameter of the pipeline
39
+ Args:
40
+ pipeline (List[Tuple[str, Union[autoPyTorchComponent, autoPyTorchChoice]]]):
41
+ The named steps of the current autopytorch pipeline
42
+
43
+ Returns:
44
+ None
45
+ """
19
46
[node [1 ]._apply_search_space_update (name = self .hyperparameter ,
20
47
new_value_range = self .value_range ,
21
48
log = self .log ,
@@ -29,30 +56,69 @@ def __str__(self) -> str:
29
56
(" log" if self .log else "" ))
30
57
31
58
32
- class HyperparameterSearchSpaceUpdates ():
59
+ class HyperparameterSearchSpaceUpdates :
60
+ """ Contains a collection of HyperparameterSearchSpaceUpdate """
33
61
def __init__ (self , updates : Optional [List [HyperparameterSearchSpaceUpdate ]] = None ) -> None :
34
62
self .updates = updates if updates is not None else []
35
63
36
64
def apply (self , pipeline : List [Tuple [str , Union [autoPyTorchComponent , autoPyTorchChoice ]]]) -> None :
65
+ """
66
+ Iteratively applies updates to the pipeline
67
+
68
+ Args:
69
+ pipeline: (List[Tuple[str, Union[autoPyTorchComponent, autoPyTorchChoice]]]):
70
+ The named steps of the current autoPyTorch pipeline
71
+
72
+ Returns:
73
+ None
74
+ """
37
75
for update in self .updates :
38
76
update .apply (pipeline )
39
77
40
78
def append (self , node_name : str , hyperparameter : str , value_range : Union [List , Tuple ],
41
79
default_value : Union [int , float , str ], log : bool = False ) -> None :
80
+ """
81
+ Add a new update
82
+
83
+ Args:
84
+ node_name (str):
85
+ The name of the node in the pipeline
86
+ hyperparameter (str):
87
+ The name of the hyperparameter
88
+ value_range (Union[List, Tuple]):
89
+ In case of categorical hyperparameter, defines the new categorical choices.
90
+ In case of numerical hyperparameter, defines the new range
91
+ in the form of (LOWER, UPPER)
92
+ default_value (Union[int, float, str]):
93
+ New default value for the hyperparameter
94
+ log (bool) (default=False):
95
+ In case of numerical hyperparameters, whether to sample on a log scale
96
+
97
+ Returns:
98
+ None
99
+ """
42
100
self .updates .append (HyperparameterSearchSpaceUpdate (node_name = node_name ,
43
101
hyperparameter = hyperparameter ,
44
102
value_range = value_range ,
45
103
default_value = default_value ,
46
104
log = log ))
47
105
48
106
def save_as_file (self , path : str ) -> None :
107
+ """
108
+ Save the updates as a file to reuse later
109
+
110
+ Args:
111
+ path (str): path of the file
112
+
113
+ Returns:
114
+ None
115
+ """
49
116
with open (path , "w" ) as f :
50
- with open (path , "w" ) as f :
51
- for update in self .updates :
52
- print (update .node_name , update .hyperparameter , # noqa: T001
53
- str (update .value_range ), "'{}'" .format (update .default_value )
54
- if isinstance (update .default_value , str ) else update .default_value ,
55
- (" log" if update .log else "" ), file = f )
117
+ for update in self .updates :
118
+ print (update .node_name , update .hyperparameter , # noqa: T001
119
+ str (update .value_range ), "'{}'" .format (update .default_value )
120
+ if isinstance (update .default_value , str ) else update .default_value ,
121
+ (" log" if update .log else "" ), file = f )
56
122
57
123
58
124
def parse_hyperparameter_search_space_updates (updates_file : Optional [str ]
0 commit comments