@@ -59,26 +59,11 @@ def _update_history_stack(
     amax_history_stack.copy_(new_amax_history_stack)


-def filter_out_small_unaligned_layers(size_limit: int) -> Callable[[nn.Linear], bool]:
-    """
-    Returns a callable that filters out small (dimensions less than the given `size_limit`)
-    and unaligned (dimenstions not divisible by 16) layers.
-    It can be passed as the `linear_layer_filter` argument to `swap_linear_with_float8_linear`.
-    """
-    return (
-        lambda linear_layer: linear_layer.in_features >= size_limit
-        and linear_layer.out_features >= size_limit
-        and linear_layer.in_features % 16 == 0
-        and linear_layer.out_features % 16 == 0
-    )
-
-
 def swap_linear_layers(
     module: nn.Module,
     from_float_func: Callable[[nn.Linear], nn.Linear],
     *,
-    skip_fqn_list: Optional[List[str]] = None,
-    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
 ) -> Optional[nn.Module]:
     """
     Generic function to swap linear layers in a module with a new type of linear layer.
@@ -90,18 +75,17 @@ def swap_linear_layers(
     Args:
         module: Module to modify.
         from_float_func: Function that accepts a linear layer and returns a new type of linear layer.
-        skip_fqn_list: If specified, a list of module FQNs to skip.
-        linear_layer_filter: If specified, only the linear layers
-            that pass the filter function will be swapped.
-        from_float_kwargs: Additional keyword arguments for from_float_func.
+        layer_filter_fn: If specified, only the modules that
+            pass the filter function will be swapped. The inputs to the
+            filter function are the FQN and module instance.

     Returns:
         nn.Module: The modified module with swapped linear layers.
     """
-    module_names_to_skip = set(skip_fqn_list or [])
-
     if isinstance(module, nn.Linear) and (
-        linear_layer_filter is None or linear_layer_filter(module)
+        # linear_layer_filter is None or linear_layer_filter(module)
+        layer_filter_fn is None
+        or layer_filter_fn("", module)
     ):
         if len(list(module.children())) > 0:
             raise AssertionError(
@@ -112,43 +96,44 @@ def swap_linear_layers(
             )

     root_module = module
-    visited_modules = {root_module}
-
-    for module_name, module in root_module.named_modules():
-        if module_name in module_names_to_skip:
-            visited_modules.add(module)

     def post_order_traversal(
-        module: nn.Module, module_name: str, parent_module: Optional[nn.Module]
+        module: nn.Module,
+        cur_fqn: Optional[str] = None,
+        parent_module: Optional[nn.Module] = None,
     ):
-        nonlocal visited_modules
+        if cur_fqn is None:
+            cur_fqn = ""
+
         for child_module_name, child_module in module.named_children():
-            if child_module not in visited_modules:
-                visited_modules.add(child_module)
-                post_order_traversal(child_module, child_module_name, module)
+            if cur_fqn == "":
+                new_fqn = child_module_name
+            else:
+                new_fqn = f"{cur_fqn}.{child_module_name}"
+
+            post_order_traversal(child_module, new_fqn, module)

         if isinstance(module, nn.Linear) and (
-            linear_layer_filter is None or linear_layer_filter(module)
+            # linear_layer_filter is None or linear_layer_filter(module)
+            layer_filter_fn is None
+            or layer_filter_fn(cur_fqn, module)
         ):
             assert (
                 parent_module is not None
             ), f"Linear root module should return early: {module}"
             new_linear_module = from_float_func(module)
-            setattr(parent_module, module_name, new_linear_module)
+            cur_module_name = cur_fqn.split(".")[-1]
+            setattr(parent_module, cur_module_name, new_linear_module)

-    post_order_traversal(root_module, "", None)
-    # Without this explicit `del`, this set only gets deleted upon an explicit
-    # garbage collection (not from when its refcount hits zero)
-    del visited_modules
+    post_order_traversal(root_module)
     return root_module


 def swap_linear_with_float8_linear(
     module: nn.Module,
     *,
-    skip_fqn_list: Optional[List[str]] = None,
     emulate: bool = False,
-    linear_layer_filter: Optional[Callable[[nn.Linear], bool]] = None,
+    layer_filter_fn: Optional[Callable[[str, nn.Module], bool]] = None,
     scaling_type_x: TensorScalingType = TensorScalingType.DYNAMIC,
     scaling_type_w: TensorScalingType = TensorScalingType.DYNAMIC,
     scaling_type_dL_dY: TensorScalingType = TensorScalingType.DYNAMIC,
@@ -158,10 +143,10 @@ def swap_linear_with_float8_linear(

     Args:
         module: Module to modify.
-        skip_fqn_list: If specified, a list of module FQNs to skip.
         emulate: If True, emulation is used instead of hardware accelerated gemm
-        linear_layer_filter: If specified, only the linear layers
-            that pass the filter function will be swapped.
+        layer_filter_fn: If specified, only the modules that
+            pass the filter function will be swapped. The inputs to the
+            filter function are the FQN and module instance.
         scaling_type_x (TensorScalingType): scaling type for `x`
         scaling_type_w (TensorScalingType): scaling type for `w`
         scaling_type_dL_dY (TensorScalingType): scaling type for `dL_dY`
@@ -179,8 +164,7 @@ def swap_linear_with_float8_linear(
     return swap_linear_layers(
         module,
         from_float,
-        skip_fqn_list=skip_fqn_list,
-        linear_layer_filter=linear_layer_filter,
+        layer_filter_fn=layer_filter_fn,
     )


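Callers that previously relied on `skip_fqn_list` or `linear_layer_filter` (for example via the removed `filter_out_small_unaligned_layers` helper) can express the same filtering through the single `layer_filter_fn` callback, which receives the module's FQN and the module instance and returns whether the layer should be swapped. Below is a minimal sketch of such a callback; the size limit, skipped FQN, and import path are illustrative assumptions rather than part of this change.

import torch.nn as nn

# Assumed import path for the swap entry point shown in the diff above.
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

SIZE_LIMIT = 32            # hypothetical threshold, mirrors the old `size_limit`
SKIP_FQNS = {"lm_head"}    # hypothetical FQNs to skip, mirrors the old `skip_fqn_list`


def layer_filter_fn(fqn: str, mod: nn.Module) -> bool:
    # Only nn.Linear instances reach this callback (see the isinstance check in
    # swap_linear_layers), but guard anyway to keep the sketch self-contained.
    if not isinstance(mod, nn.Linear):
        return False
    # Old `skip_fqn_list` behavior: never swap explicitly listed modules.
    if fqn in SKIP_FQNS:
        return False
    # Old `filter_out_small_unaligned_layers` behavior: swap only layers whose
    # dimensions are at least SIZE_LIMIT and divisible by 16.
    return (
        mod.in_features >= SIZE_LIMIT
        and mod.out_features >= SIZE_LIMIT
        and mod.in_features % 16 == 0
        and mod.out_features % 16 == 0
    )


model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 13))
# The second layer (64 -> 13) stays nn.Linear: 13 is below the size limit and
# not divisible by 16, so the filter returns False for it.
model = swap_linear_with_float8_linear(model, layer_filter_fn=layer_filter_fn)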