Commit b121fab

joshuadeng authored and facebook-github-bot committed
Convert input/output in stats to be MB per iteration (#178)
Summary:
Pull Request resolved: #178

Instead of naively showing input as just the pooling factor and output as just the embedding dimension, this diff changes planner stats to use the actual size of input & output in terms of megabytes per iteration:

**input**: global_batch_size * pooling factor * sizeof(dtype of input)
**output**: global_batch_size * (output size (1 in pooled)) * sizeof(dtype of emb) * emb_dim

This provides a sense of scale for data coming in and out, and additionally makes plans with multiple sharding types directly comparable.

Also fixes a bug with TWCW: we incorrectly specified the ranks as the entire world size when they should be limited to the local world of the host that the parameter is sharded on.

Reviewed By: dstaay-fb

Differential Revision: D35153224

fbshipit-source-id: c1e7d717ec0c1d074f7e059d843fba2d287eee56
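For a concrete sense of the two formulas, here is a minimal sketch (illustrative only; it assumes 8-byte int64 input ids, matching BIGINT_DTYPE, and 4-byte fp32 embeddings; the helper names are hypothetical):

    # Hypothetical helpers restating the commit-message formulas.
    def input_mb_per_iter(global_batch_size: int, pooling_factor: float) -> float:
        # input: global_batch_size * pooling factor * sizeof(dtype of input)
        return global_batch_size * pooling_factor * 8 / (1024 * 1024)

    def output_mb_per_iter(global_batch_size: int, emb_dim: int, num_features: int) -> float:
        # output: global_batch_size * num_outputs (1 in pooled) * emb_dim
        # * num_features * sizeof(dtype of emb)
        return global_batch_size * 1 * emb_dim * num_features * 4 / (1024 * 1024)

    # e.g. global batch of 8192, pooling factor 20, one 128-dim feature:
    print(input_mb_per_iter(8192, 20.0))     # 1.25 (MB of ids in per iteration)
    print(output_mb_per_iter(8192, 128, 1))  # 4.0 (MB of pooled embeddings out)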
1 parent bacdb06 commit b121fab

File tree

2 files changed: +80 -70 lines


torchrec/distributed/planner/stats.py

Lines changed: 75 additions & 69 deletions
@@ -9,14 +9,15 @@
 from collections import defaultdict
 from typing import Union, Tuple, Optional, Any, List, Dict, cast
 
+from torchrec.distributed.planner.constants import BIGINT_DTYPE
 from torchrec.distributed.planner.types import (
     ShardingOption,
     Stats,
     Topology,
     ParameterConstraints,
     Storage,
 )
-from torchrec.distributed.planner.utils import bytes_to_gb
+from torchrec.distributed.planner.utils import bytes_to_gb, bytes_to_mb
 from torchrec.distributed.types import ShardingType, ParameterSharding, ShardingPlan
 
 
@@ -45,8 +46,7 @@ def log(
         Logs stats for a given sharding plan to stdout.
 
         Provides a tabular view of stats for the given sharding plan with per device
-        storage usage (HBM and DDR), perf, input (pooling factors), output (embedding
-        dimension), and number and type of shards.
+        storage usage (HBM and DDR), perf, input, output, and number/type of shards.
 
         Args:
             sharding_plan (ShardingPlan): sharding plan chosen by the ShardingPlanner.
@@ -64,7 +64,7 @@ def log(
             for param_name, value in param_dict.items()
         }
         stats: Dict[int, Dict[str, Any]] = {
-            rank: {"type": {}, "pooling_factor": 0.0, "embedding_dims": 0}
+            rank: {"type": {}, "input_sizes": 0.0, "output_sizes": 0.0}
             for rank in range(topology.world_size)
         }
 
@@ -78,7 +78,7 @@ def log(
                 continue
             shard: ParameterSharding = shard_by_fqn[fqn]
 
-            ranks, pooling_factor, output_dims = self._get_shard_stats(
+            ranks, input_sizes, output_sizes = self._get_shard_stats(
                 shard=shard,
                 sharding_option=sharding_option,
                 world_size=topology.world_size,
@@ -92,8 +92,8 @@ def log(
             for i, rank in enumerate(ranks):
                 count = stats[rank]["type"].get(sharding_type_abbr, 0)
                 stats[rank]["type"][sharding_type_abbr] = count + 1
-                stats[rank]["pooling_factor"] += pooling_factor[i]
-                stats[rank]["embedding_dims"] += output_dims[i]
+                stats[rank]["input_sizes"] += input_sizes[i]
+                stats[rank]["output_sizes"] += output_sizes[i]
 
         used_hbm = [0] * topology.world_size
         used_ddr = [0] * topology.world_size
@@ -107,14 +107,22 @@ def log(
             perf[rank] += cast(float, shard.perf)
 
         table: List[List[Union[str, int]]] = [
-            ["Rank", "HBM (GB)", "DDR (GB)", "Perf", "Input", "Output", "Shards"],
+            [
+                "Rank",
+                "HBM (GB)",
+                "DDR (GB)",
+                "Perf (ms)",
+                "Input (MB)",
+                "Output (MB)",
+                "Shards",
+            ],
             [
                 "------",
                 "----------",
                 "----------",
-                "------",
-                "-------",
-                "--------",
+                "-----------",
+                "------------",
+                "-------------",
                 "--------",
             ],
         ]
@@ -135,8 +143,8 @@ def log(
             rank_hbm = f"{round(used_hbm_gb, 1)} ({used_hbm_ratio:.0%})"
             rank_ddr = f"{round(used_ddr_gb, 1)} ({used_ddr_ratio:.0%})"
             rank_perf = f"{round(perf[rank], 2)}"
-            rank_pooling = f"{int(stats[rank]['pooling_factor']):,}"
-            rank_dims = f"{stats[rank]['embedding_dims']:,}"
+            rank_input = f"{round(stats[rank]['input_sizes'], 2)}"
+            rank_output = f"{round(stats[rank]['output_sizes'], 2)}"
            rank_shards = " ".join(
                 f"{sharding_type}: {num_tables}"
                 for sharding_type, num_tables in sorted(stats[rank]["type"].items())
@@ -147,8 +155,8 @@ def log(
                     rank_hbm,
                     rank_ddr,
                     rank_perf,
-                    rank_pooling,
-                    rank_dims,
+                    rank_input,
+                    rank_output,
                     rank_shards,
                 ]
             )
@@ -157,12 +165,12 @@ def log(
 
         if debug:
             param_table: List[List[Union[str, int]]] = [
-                ["FQN", "Sharding", "Compute Kernel", "Perf", "Ranks"],
+                ["FQN", "Sharding", "Compute Kernel", "Perf (ms)", "Ranks"],
                 [
                     "-----",
                     "----------",
                     "----------------",
-                    "------",
+                    "-----------",
                     "-------",
                 ],
             ]
@@ -203,8 +211,10 @@ def log(
             logger.info(f"# {row: <{width-3}}#")
 
         logger.info(f"#{'' : ^{width-2}}#")
-        legend = "Input: pooling factor, Output: output dim per sample, Shards: number of tables"
+        legend = "Input: MB/iteration, Output: MB/iteration, Shards: number of tables"
         logger.info(f"# {legend: <{width-3}}#")
+        hbm_info = "HBM: est. peak memory usage for shards - parameter, comms, optimizer, and gradients"
+        logger.info(f"# {hbm_info: <{width-3}}#")
         logger.info(f"#{'' : ^{width-2}}#")
 
         compute_kernels_count = [
@@ -231,70 +241,66 @@ def _get_shard_stats(
         world_size: int,
         local_size: int,
         constraints: Optional[Dict[str, ParameterConstraints]] = None,
-    ) -> Tuple[List[int], List[float], List[int]]:
+    ) -> Tuple[List[int], List[float], List[float]]:
         """
-        Gets ranks, pooling factors, and embedding dimensions per shard.
+        Gets ranks, input sizes, and output sizes per shard.
+        Input size is a function of pooling factor.
+        Output size is a function of embedding dimension * number of features.
 
         Returns:
             ranks: list of ranks.
-            pooling_factor: list of pooling factors across ranks.
-            output_dims: list of output dimensions across ranks.
+            input_sizes: input size per iter in MB across ranks for given shard.
+            output_sizes: output size per iter in MB across ranks for given shard.
         """
+        assert shard.ranks
+        ranks = shard.ranks
 
-        ranks = list(range(world_size))
-        pooling_factor = [
-            sum(constraints[sharding_option.name].pooling_factors)
+        batch_size = world_size * sharding_option.batch_size
+        input_data_type_size = BIGINT_DTYPE
+        pooling_factor = (
+            float(sum(constraints[sharding_option.name].pooling_factors))
             if constraints and constraints.get(sharding_option.name)
-            else 0.0
-        ]
-        output_dims = [
-            sharding_option.tensor.shape[1] * len(sharding_option.input_lengths)
-        ]
+            else 1.0
+        )
+        num_features = len(sharding_option.input_lengths)
+        output_data_type_size = sharding_option.tensor.element_size()
+        num_outputs = 1  # for pooled embeddings
 
         if shard.sharding_type == ShardingType.DATA_PARALLEL.value:
-            output_dims = output_dims * len(ranks)
-            pooling_factor = pooling_factor * len(ranks)
-
-        elif shard.sharding_type == ShardingType.TABLE_WISE.value:
-            assert shard.ranks
-            ranks = shard.ranks
-
-        elif shard.sharding_type == ShardingType.COLUMN_WISE.value:
-            assert shard.ranks
-            ranks = shard.ranks
-            output_dims = [
-                int(shard.shard_sizes[1])
-                # pyre-ignore [16]
-                for shard in shard.sharding_spec.shards
-            ]
-            pooling_factor = pooling_factor * len(ranks)
-
+            batch_size = sharding_option.batch_size
         elif shard.sharding_type == ShardingType.ROW_WISE.value:
-            pooling_factor = [pooling_factor[0] / world_size] * len(ranks)
-            output_dims = output_dims * len(ranks)
-
+            pooling_factor /= world_size
         elif shard.sharding_type == ShardingType.TABLE_ROW_WISE.value:
-            assert shard.ranks
-            host_id = shard.ranks[0] // local_size
-            ranks = list(range(host_id * local_size, (host_id + 1) * local_size))
-            pooling_factor = [pooling_factor[0] / local_size] * len(ranks)
-            output_dims = output_dims * len(ranks)
-
-        elif shard.sharding_type == ShardingType.TABLE_COLUMN_WISE.value:
-            assert shard.ranks
-            ranks = shard.ranks
-            pooling_factor = pooling_factor * len(ranks)
-            output_dims = [
-                int(shard.shard_sizes[1] * len(sharding_option.input_lengths))
+            pooling_factor /= local_size
+
+        input_sizes = [
+            bytes_to_mb(batch_size * pooling_factor * input_data_type_size)
+        ] * len(ranks)
+        output_sizes = (
+            [
+                bytes_to_mb(
+                    batch_size
+                    * num_outputs
+                    * sharding_option.tensor.shape[1]  # embedding dim
+                    * num_features
+                    * output_data_type_size
+                )
+            ]
+            * len(ranks)
+            if shard.sharding_type == ShardingType.DATA_PARALLEL.value
+            else [
+                bytes_to_mb(
                    batch_size
                    * num_outputs
                    * int(shard.shard_sizes[1])  # embedding dim
                    * num_features
                    * output_data_type_size
                )
                # pyre-ignore [16]
                 for shard in shard.sharding_spec.shards
             ]
-
-        else:
-            raise ValueError(
-                f"Unrecognized or unsupported sharding type provided: {shard.sharding_type}"
-            )
-
-        return ranks, pooling_factor, output_dims
+        )
+        return ranks, input_sizes, output_sizes
 
 
 def _get_sharding_type_abbr(sharding_type: str) -> str:
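For readability, here is a condensed standalone sketch of the new scaling logic in _get_shard_stats (it mirrors the diff above with the same variable names; the string values assume ShardingType's lowercase enum values, and this is not the actual method):

    # Illustrative sketch: how batch size and pooling factor are scaled
    # per sharding type before being converted to MB per iteration.
    def scale_inputs(sharding_type: str, batch_size: int, world_size: int,
                     local_size: int, pooling_factor: float):
        global_batch_size = world_size * batch_size
        if sharding_type == "data_parallel":
            # each rank only reads its local batch
            global_batch_size = batch_size
        elif sharding_type == "row_wise":
            # ids for the table are split across all ranks
            pooling_factor /= world_size
        elif sharding_type == "table_row_wise":
            # ids are split across the ranks of a single host
            pooling_factor /= local_size
        return global_batch_size, pooling_factor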

torchrec/distributed/planner/utils.py

Lines changed: 5 additions & 1 deletion
@@ -7,7 +7,7 @@
 
 import operator
 from functools import reduce
-from typing import Iterable, Any, Type
+from typing import Union, Iterable, Any, Type
 
 # pyre-ignore[2]
 def sharder_name(t: Type[Any]) -> str:
@@ -18,6 +18,10 @@ def bytes_to_gb(num_bytes: int) -> float:
     return float(num_bytes / (1024 * 1024 * 1024))
 
 
+def bytes_to_mb(num_bytes: Union[float, int]) -> float:
+    return float(num_bytes / (1024 * 1024))
+
+
 def gb_to_bytes(gb: float) -> int:
     return int(gb * 1024 * 1024 * 1024)
 
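The new helper mirrors the existing bytes_to_gb; a quick usage check (import path as in this diff):

    from torchrec.distributed.planner.utils import bytes_to_gb, bytes_to_mb

    print(bytes_to_mb(1_310_720))      # 1.25
    print(bytes_to_gb(4 * 1024 ** 3))  # 4.0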
