 from collections import defaultdict
 from typing import Union, Tuple, Optional, Any, List, Dict, cast
 
+from torchrec.distributed.planner.constants import BIGINT_DTYPE
 from torchrec.distributed.planner.types import (
     ShardingOption,
     Stats,
     Topology,
     ParameterConstraints,
     Storage,
 )
-from torchrec.distributed.planner.utils import bytes_to_gb
+from torchrec.distributed.planner.utils import bytes_to_gb, bytes_to_mb
 from torchrec.distributed.types import ShardingType, ParameterSharding, ShardingPlan
 
 
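Note on the new imports: BIGINT_DTYPE is used below as the per-id input size, and bytes_to_mb converts raw byte counts for display. As a rough sketch of the assumed helper behavior (not the actual torchrec.distributed.planner.utils implementation), bytes_to_mb presumably mirrors the existing bytes_to_gb:

def bytes_to_mb(num_bytes: float) -> float:
    # Assumed behavior: plain mebibyte conversion, analogous to bytes_to_gb.
    return num_bytes / (1024 * 1024)
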
@@ -45,8 +46,7 @@ def log(
         Logs stats for a given sharding plan to stdout.
 
         Provides a tabular view of stats for the given sharding plan with per device
-        storage usage (HBM and DDR), perf, input (pooling factors), output (embedding
-        dimension), and number and type of shards.
+        storage usage (HBM and DDR), perf, input, output, and number/type of shards.
 
         Args:
             sharding_plan (ShardingPlan): sharding plan chosen by the ShardingPlanner.
@@ -64,7 +64,7 @@ def log(
             for param_name, value in param_dict.items()
         }
         stats: Dict[int, Dict[str, Any]] = {
-            rank: {"type": {}, "pooling_factor": 0.0, "embedding_dims": 0}
+            rank: {"type": {}, "input_sizes": 0.0, "output_sizes": 0.0}
             for rank in range(topology.world_size)
         }
 
@@ -78,7 +78,7 @@ def log(
                 continue
             shard: ParameterSharding = shard_by_fqn[fqn]
 
-            ranks, pooling_factor, output_dims = self._get_shard_stats(
+            ranks, input_sizes, output_sizes = self._get_shard_stats(
                 shard=shard,
                 sharding_option=sharding_option,
                 world_size=topology.world_size,
@@ -92,8 +92,8 @@ def log(
             for i, rank in enumerate(ranks):
                 count = stats[rank]["type"].get(sharding_type_abbr, 0)
                 stats[rank]["type"][sharding_type_abbr] = count + 1
-                stats[rank]["pooling_factor"] += pooling_factor[i]
-                stats[rank]["embedding_dims"] += output_dims[i]
+                stats[rank]["input_sizes"] += input_sizes[i]
+                stats[rank]["output_sizes"] += output_sizes[i]
 
         used_hbm = [0] * topology.world_size
         used_ddr = [0] * topology.world_size
@@ -107,14 +107,22 @@ def log(
                 perf[rank] += cast(float, shard.perf)
 
         table: List[List[Union[str, int]]] = [
-            ["Rank", "HBM (GB)", "DDR (GB)", "Perf", "Input", "Output", "Shards"],
+            [
+                "Rank",
+                "HBM (GB)",
+                "DDR (GB)",
+                "Perf (ms)",
+                "Input (MB)",
+                "Output (MB)",
+                "Shards",
+            ],
             [
                 "------",
                 "----------",
                 "----------",
-                "------",
-                "-------",
-                "--------",
+                "-----------",
+                "------------",
+                "-------------",
                 "--------",
             ],
         ]
@@ -135,8 +143,8 @@ def log(
             rank_hbm = f"{round(used_hbm_gb, 1)} ({used_hbm_ratio:.0%})"
             rank_ddr = f"{round(used_ddr_gb, 1)} ({used_ddr_ratio:.0%})"
             rank_perf = f"{round(perf[rank], 2)}"
-            rank_pooling = f"{int(stats[rank]['pooling_factor']):,}"
-            rank_dims = f"{stats[rank]['embedding_dims']:,}"
+            rank_input = f"{round(stats[rank]['input_sizes'], 2)}"
+            rank_output = f"{round(stats[rank]['output_sizes'], 2)}"
             rank_shards = " ".join(
                 f"{sharding_type}: {num_tables}"
                 for sharding_type, num_tables in sorted(stats[rank]["type"].items())
@@ -147,8 +155,8 @@ def log(
                     rank_hbm,
                     rank_ddr,
                     rank_perf,
-                    rank_pooling,
-                    rank_dims,
+                    rank_input,
+                    rank_output,
                     rank_shards,
                 ]
             )
@@ -157,12 +165,12 @@ def log(
 
         if debug:
             param_table: List[List[Union[str, int]]] = [
-                ["FQN", "Sharding", "Compute Kernel", "Perf", "Ranks"],
+                ["FQN", "Sharding", "Compute Kernel", "Perf (ms)", "Ranks"],
                 [
                     "-----",
                     "----------",
                     "----------------",
-                    "------",
+                    "-----------",
                     "-------",
                 ],
             ]
@@ -203,8 +211,10 @@ def log(
             logger.info(f"# {row: <{width - 3}} #")
 
         logger.info(f"#{'' : ^{width - 2}}#")
-        legend = "Input: pooling factor, Output: output dim per sample, Shards: number of tables"
+        legend = "Input: MB/iteration, Output: MB/iteration, Shards: number of tables"
         logger.info(f"# {legend: <{width - 3}} #")
+        hbm_info = "HBM: est. peak memory usage for shards - parameter, comms, optimizer, and gradients"
+        logger.info(f"# {hbm_info: <{width - 3}} #")
         logger.info(f"#{'' : ^{width - 2}}#")
 
         compute_kernels_count = [
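For readers unfamiliar with the f-string alignment used in the banner above, here is a minimal, self-contained sketch of the padding idiom; width and row are made-up values, not the planner's actual numbers:

# Hypothetical values; in the planner, width is derived from the table contents.
width = 40
row = "0    1.2 (60%)   0.0 (0%)"
# "<" left-aligns row in a field of width - 3 characters, leaving room for "# " and " #".
print(f"# {row: <{width - 3}} #")
# "^" centers an empty string, producing a blank line inside the box.
print(f"#{'' : ^{width - 2}}#")
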
@@ -231,70 +241,66 @@ def _get_shard_stats(
         world_size: int,
         local_size: int,
         constraints: Optional[Dict[str, ParameterConstraints]] = None,
-    ) -> Tuple[List[int], List[float], List[int]]:
+    ) -> Tuple[List[int], List[float], List[float]]:
         """
-        Gets ranks, pooling factors, and embedding dimensions per shard.
+        Gets ranks, input sizes, and output sizes per shard.
+        Input size is a function of pooling factor.
+        Output size is a function of embedding dimension * number of features.
 
         Returns:
             ranks: list of ranks.
-            pooling_factor: list of pooling factors across ranks.
-            output_dims: list of output dimensions across ranks.
+            input_sizes: input size per iter in MB across ranks for given shard.
+            output_sizes: output size per iter in MB across ranks for given shard.
         """
+        assert shard.ranks
+        ranks = shard.ranks
 
-        ranks = list(range(world_size))
-        pooling_factor = [
-            sum(constraints[sharding_option.name].pooling_factors)
+        batch_size = world_size * sharding_option.batch_size
+        input_data_type_size = BIGINT_DTYPE
+        pooling_factor = (
+            float(sum(constraints[sharding_option.name].pooling_factors))
             if constraints and constraints.get(sharding_option.name)
-            else 0.0
-        ]
-        output_dims = [
-            sharding_option.tensor.shape[1] * len(sharding_option.input_lengths)
-        ]
+            else 1.0
+        )
+        num_features = len(sharding_option.input_lengths)
+        output_data_type_size = sharding_option.tensor.element_size()
+        num_outputs = 1  # for pooled embeddings
 
         if shard.sharding_type == ShardingType.DATA_PARALLEL.value:
-            output_dims = output_dims * len(ranks)
-            pooling_factor = pooling_factor * len(ranks)
-
-        elif shard.sharding_type == ShardingType.TABLE_WISE.value:
-            assert shard.ranks
-            ranks = shard.ranks
-
-        elif shard.sharding_type == ShardingType.COLUMN_WISE.value:
-            assert shard.ranks
-            ranks = shard.ranks
-            output_dims = [
-                int(shard.shard_sizes[1])
-                # pyre-ignore [16]
-                for shard in shard.sharding_spec.shards
-            ]
-            pooling_factor = pooling_factor * len(ranks)
-
+            batch_size = sharding_option.batch_size
         elif shard.sharding_type == ShardingType.ROW_WISE.value:
-            pooling_factor = [pooling_factor[0] / world_size] * len(ranks)
-            output_dims = output_dims * len(ranks)
-
+            pooling_factor /= world_size
         elif shard.sharding_type == ShardingType.TABLE_ROW_WISE.value:
-            assert shard.ranks
-            host_id = shard.ranks[0] // local_size
-            ranks = list(range(host_id * local_size, (host_id + 1) * local_size))
-            pooling_factor = [pooling_factor[0] / local_size] * len(ranks)
-            output_dims = output_dims * len(ranks)
-
-        elif shard.sharding_type == ShardingType.TABLE_COLUMN_WISE.value:
-            assert shard.ranks
-            ranks = shard.ranks
-            pooling_factor = pooling_factor * len(ranks)
-            output_dims = [
-                int(shard.shard_sizes[1] * len(sharding_option.input_lengths))
+            pooling_factor /= local_size
+
+        input_sizes = [
+            bytes_to_mb(batch_size * pooling_factor * input_data_type_size)
+        ] * len(ranks)
+        output_sizes = (
+            [
+                bytes_to_mb(
+                    batch_size
+                    * num_outputs
+                    * sharding_option.tensor.shape[1]  # embedding dim
+                    * num_features
+                    * output_data_type_size
+                )
+            ]
+            * len(ranks)
+            if shard.sharding_type == ShardingType.DATA_PARALLEL.value
+            else [
+                bytes_to_mb(
+                    batch_size
+                    * num_outputs
+                    * int(shard.shard_sizes[1])  # embedding dim
+                    * num_features
+                    * output_data_type_size
+                )
+                # pyre-ignore [16]
                 for shard in shard.sharding_spec.shards
             ]
-
-        else:
-            raise ValueError(
-                f"Unrecognized or unsupported sharding type provided: {shard.sharding_type}"
-            )
-
-        return ranks, pooling_factor, output_dims
+        )
+        return ranks, input_sizes, output_sizes
 
 
 def _get_sharding_type_abbr(sharding_type: str) -> str:
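To make the new size accounting concrete, here is a standalone sketch that reproduces the same arithmetic outside the planner. Every concrete number below is hypothetical, and the 8-byte input dtype (BIGINT_DTYPE for int64 ids) and 4-byte output dtype (fp32 embeddings) are assumptions; in the planner these values come from the Topology, ParameterConstraints, and ShardingOption.

def bytes_to_mb(num_bytes: float) -> float:
    # assumed helper behavior, mirroring torchrec.distributed.planner.utils
    return num_bytes / (1024 * 1024)

world_size = 16
local_batch_size = 512
batch_size = world_size * local_batch_size  # global batch seen by a non-DP shard
pooling_factor = 20.0       # avg ids per sample for this table (hypothetical)
num_features = 1            # len(sharding_option.input_lengths)
embedding_dim = 128         # sharding_option.tensor.shape[1]
input_data_type_size = 8    # assumed BIGINT_DTYPE: int64 ids
output_data_type_size = 4   # assumed fp32 embeddings (tensor.element_size())
num_outputs = 1             # pooled embeddings

# Input: ids routed to the shard per iteration.
input_mb = bytes_to_mb(batch_size * pooling_factor * input_data_type_size)
# Output: pooled embeddings produced per iteration.
output_mb = bytes_to_mb(
    batch_size * num_outputs * embedding_dim * num_features * output_data_type_size
)
print(f"input ~{input_mb:.2f} MB/iter, output ~{output_mb:.2f} MB/iter")
# -> input ~1.25 MB/iter, output ~4.00 MB/iter for these numbers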