- Reland D75563906 #4865
This pull request was exported from Phabricator. Differential Revision: D77698275
a1d116e to 2fb84a9
Summary:
Reland D75563906 that was backed out, with fixes.

The problem was that the grid was not big enough for the given config.

Also ensured vectorization, which allows 1.4Tb/s.

Test:
```
Running correctness tests...

Testing correctness for dtype: torch.float32
✓ Passed: shape (32, 32)
✓ Passed: shape (64, 32)
✓ Passed: shape (256, 128)
✓ Passed: shape (512, 1024)
✓ Passed: shape (1024, 2048)
✓ Passed: shape (2048, 2048)
✓ Passed: shape (4096, 16384)
✓ Passed: shape (70000, 64)
✓ Passed: shape (131072, 512)
✓ Passed: shape (1000, 520)
✓ Passed: shape (4005, 4005)
✓ Passed: shape (10000, 1000)
✓ Passed: shape (1024, 10000)
✓ Passed: shape (8192, 4096)
✓ Passed: shape (10000, 10000)
✓ Passed: shape (3072, 10000)
✓ Passed: shape (6144, 10000)
✓ Passed: shape (1024, 20000)
✓ Passed: shape (1024, 20000)
✓ Passed: shape (512, 1536)
✓ Passed: shape (512, 6144)
✓ Passed: shape (512, 10240)
✓ Passed: shape (1000, 1000)
✓ Passed: shape (2000, 2000)
✓ Passed: shape (10240, 10240)
✓ Passed: shape (384, 128)
✓ Passed: shape (2048, 1024)
✓ Passed: shape (267, 513)
✓ Passed: shape (67, 123479)
✓ Passed: shape (1024, 123479)
✓ Passed: shape (1234154, 512)
✓ Passed: shape (2048, 66679)
✓ Passed: shape (200, 256)
✓ Passed: shape (1000, 256)
✓ Passed: shape (6000, 256)
✓ Passed: shape (6272, 256)
✓ Passed: shape (200, 512)
✓ Passed: shape (1000, 512)
✓ Passed: shape (6000, 512)
✓ Passed: shape (6272, 512)
✓ Passed: shape (200, 1024)
✓ Passed: shape (1000, 1024)
✓ Passed: shape (6000, 1024)
✓ Passed: shape (6272, 1024)
✓ Passed: shape (200, 2048)
✓ Passed: shape (1000, 2048)
✓ Passed: shape (6000, 2048)
✓ Passed: shape (6272, 2048)
✓ Passed: shape (200, 3072)
✓ Passed: shape (1000, 3072)
✓ Passed: shape (6000, 3072)
✓ Passed: shape (6272, 3072)
✓ Passed: shape (3000000, 512)

Testing correctness for dtype: torch.float16
✓ Passed: shape (32, 32)
✓ Passed: shape (64, 32)
✓ Passed: shape (256, 128)
✓ Passed: shape (512, 1024)
✓ Passed: shape (1024, 2048)
✓ Passed: shape (2048, 2048)
✓ Passed: shape (4096, 16384)
✓ Passed: shape (70000, 64)
✓ Passed: shape (131072, 512)
✓ Passed: shape (1000, 520)
✓ Passed: shape (4005, 4005)
✓ Passed: shape (10000, 1000)
✓ Passed: shape (1024, 10000)
✓ Passed: shape (8192, 4096)
✓ Passed: shape (10000, 10000)
✓ Passed: shape (3072, 10000)
✓ Passed: shape (6144, 10000)
✓ Passed: shape (1024, 20000)
✓ Passed: shape (1024, 20000)
✓ Passed: shape (512, 1536)
✓ Passed: shape (512, 6144)
✓ Passed: shape (512, 10240)
✓ Passed: shape (1000, 1000)
✓ Passed: shape (2000, 2000)
✓ Passed: shape (10240, 10240)
✓ Passed: shape (384, 128)
✓ Passed: shape (2048, 1024)
✓ Passed: shape (267, 513)
✓ Passed: shape (67, 123479)
✓ Passed: shape (1024, 123479)
✓ Passed: shape (1234154, 512)
✓ Passed: shape (2048, 66679)
✓ Passed: shape (200, 256)
✓ Passed: shape (1000, 256)
✓ Passed: shape (6000, 256)
✓ Passed: shape (6272, 256)
✓ Passed: shape (200, 512)
✓ Passed: shape (1000, 512)
✓ Passed: shape (6000, 512)
✓ Passed: shape (6272, 512)
✓ Passed: shape (200, 1024)
✓ Passed: shape (1000, 1024)
✓ Passed: shape (6000, 1024)
✓ Passed: shape (6272, 1024)
✓ Passed: shape (200, 2048)
✓ Passed: shape (1000, 2048)
✓ Passed: shape (6000, 2048)
✓ Passed: shape (6272, 2048)
✓ Passed: shape (200, 3072)
✓ Passed: shape (1000, 3072)
✓ Passed: shape (6000, 3072)
✓ Passed: shape (6272, 3072)
✓ Passed: shape (3000000, 512)

Testing correctness for dtype: torch.bfloat16
✓ Passed: shape (32, 32)
✓ Passed: shape (64, 32)
✓ Passed: shape (256, 128)
✓ Passed: shape (512, 1024)
✓ Passed: shape (1024, 2048)
✓ Passed: shape (2048, 2048)
✓ Passed: shape (4096, 16384)
✓ Passed: shape (70000, 64)
✓ Passed: shape (131072, 512)
✓ Passed: shape (1000, 520)
✓ Passed: shape (4005, 4005)
✓ Passed: shape (10000, 1000)
✓ Passed: shape (1024, 10000)
✓ Passed: shape (8192, 4096)
✓ Passed: shape (10000, 10000)
✓ Passed: shape (3072, 10000)
✓ Passed: shape (6144, 10000)
✓ Passed: shape (1024, 20000)
✓ Passed: shape (1024, 20000)
✓ Passed: shape (512, 1536)
✓ Passed: shape (512, 6144)
✓ Passed: shape (512, 10240)
✓ Passed: shape (1000, 1000)
✓ Passed: shape (2000, 2000)
✓ Passed: shape (10240, 10240)
✓ Passed: shape (384, 128)
✓ Passed: shape (2048, 1024)
✓ Passed: shape (267, 513)
✓ Passed: shape (67, 123479)
✓ Passed: shape (1024, 123479)
✓ Passed: shape (1234154, 512)
✓ Passed: shape (2048, 66679)
✓ Passed: shape (200, 256)
✓ Passed: shape (1000, 256)
✓ Passed: shape (6000, 256)
✓ Passed: shape (6272, 256)
✓ Passed: shape (200, 512)
✓ Passed: shape (1000, 512)
✓ Passed: shape (6000, 512)
✓ Passed: shape (6272, 512)
✓ Passed: shape (200, 1024)
✓ Passed: shape (1000, 1024)
✓ Passed: shape (6000, 1024)
✓ Passed: shape (6272, 1024)
✓ Passed: shape (200, 2048)
✓ Passed: shape (1000, 2048)
✓ Passed: shape (6000, 2048)
✓ Passed: shape (6272, 2048)
✓ Passed: shape (200, 3072)
✓ Passed: shape (1000, 3072)
✓ Passed: shape (6000, 3072)
✓ Passed: shape (6272, 3072)
✓ Passed: shape (3000000, 512)

All correctness tests passed!

100%|██████████| 53/53 [00:31<00:00, 1.68it/s]
100%|██████████| 53/53 [00:27<00:00, 1.90it/s]
100%|██████████| 53/53 [00:28<00:00, 1.86it/s]
100%|██████████| 3/3 [01:27<00:00, 29.29s/it]
100%|██████████| 53/53 [00:31<00:00, 1.69it/s]
100%|██████████| 53/53 [00:31<00:00, 1.70it/s]
100%|██████████| 53/53 [00:30<00:00, 1.73it/s]
100%|██████████| 3/3 [01:33<00:00, 31.02s/it]

[------------------------------------ Not vectorized ------------------------------------]
              | fwd, torch.float32 | fwd, torch.float16 | fwd, torch.bfloat16
1 threads: -------------------------------------------------------------------------------
32, 32        |               29.2 |               20.6 |                20.3
64, 32        |               29.3 |               20.5 |                21.2
256, 128      |               23.5 |               20.4 |                20.0
512, 1024     |               22.6 |               20.6 |                20.0
1024, 2048    |               52.9 |               40.9 |                40.9
2048, 2048    |               88.8 |               58.1 |                58.1
4096, 16384   |             1327.5 |              918.2 |               921.5
70000, 64     |               94.0 |               67.1 |                67.1
131072, 512   |              908.2 |              710.9 |               716.4
1000, 520     |               22.4 |               20.1 |                20.6
4005, 4005    |              339.3 |              237.7 |               238.7
10000, 1000   |              207.8 |              129.9 |               130.7
1024, 10000   |              357.8 |              272.0 |               277.6
8192, 4096    |              669.1 |              463.8 |               465.1
10000, 10000  |             1934.2 |             1283.3 |              1287.1
3072, 10000   |              676.7 |              484.2 |               484.9
6144, 10000   |             1165.3 |              767.3 |               770.7
1024, 20000   |              703.6 |              577.5 |               578.6
512, 1536     |               30.6 |               25.8 |                25.6
512, 6144     |               99.0 |               80.6 |                80.6
512, 10240    |              205.4 |              128.6 |               128.5
1000, 1000    |               30.2 |               24.0 |                23.8
2000, 2000    |               81.0 |               56.4 |                56.7
10240, 10240  |             1963.8 |             1323.1 |              1326.9
384, 128      |               20.7 |               20.5 |                20.0
2048, 1024    |               45.3 |               33.0 |                33.0
267, 513      |               20.7 |               20.1 |                19.8
67, 123479    |             2322.7 |             1230.8 |              1313.2
1024, 123479  |             4244.2 |             3485.9 |              3491.3
1234154, 512  |             6838.8 |             5890.0 |              5956.4
2048, 66679   |             3304.3 |             2477.3 |              2483.4
200, 256      |               20.5 |               19.7 |                19.7
1000, 256     |               20.5 |               19.9 |                19.6
6000, 256     |               32.9 |               23.7 |                23.8
6272, 256     |               34.0 |               24.8 |                24.9
200, 512      |               21.0 |               19.7 |                20.2
1000, 512     |               20.3 |               20.2 |                19.8
6000, 512     |               56.4 |               38.4 |                38.6
6272, 512     |               59.0 |               40.3 |                40.9
200, 1024     |               20.7 |               19.7 |                20.0
1000, 1024    |               30.2 |               24.1 |                24.0
6000, 1024    |              118.6 |               67.4 |                67.9
6272, 1024    |              123.8 |               73.3 |                72.6
200, 2048     |               28.9 |               26.3 |                26.2
1000, 2048    |               52.3 |               40.5 |                40.2
6000, 2048    |              238.6 |              159.3 |               159.2
6272, 2048    |              246.8 |              165.6 |               165.7
200, 3072     |               39.9 |               36.3 |                36.2
1000, 3072    |               75.0 |               56.8 |                56.5
6000, 3072    |              352.0 |              240.9 |               241.9
6272, 3072    |              365.4 |              249.1 |               250.1
3000000, 512  |            16525.9 |            14259.5 |             14365.3

Times are in microseconds (us).

[-------------------------------------- Vectorized --------------------------------------]
              | fwd, torch.float32 | fwd, torch.float16 | fwd, torch.bfloat16
1 threads: -------------------------------------------------------------------------------
32, 32        |               19.4 |               19.6 |                19.6
64, 32        |               19.6 |               19.5 |                20.1
256, 128      |               19.4 |               19.6 |                20.2
512, 1024     |               19.3 |               19.9 |                20.0
1024, 2048    |               30.2 |               29.4 |                29.3
2048, 2048    |               42.4 |               35.1 |                35.0
4096, 16384   |              613.5 |              562.1 |               564.2
70000, 64     |               50.2 |               50.1 |                50.2
131072, 512   |              548.1 |              467.3 |               470.2
1000, 520     |               19.7 |               20.3 |                19.8
4005, 4005    |              206.0 |              240.6 |               240.6
10000, 1000   |               98.3 |               75.4 |                75.1
1024, 10000   |              250.9 |              200.9 |               208.1
8192, 4096    |              317.8 |              285.0 |               286.0
10000, 10000  |              857.6 |              742.6 |               743.9
3072, 10000   |              347.1 |              320.1 |               320.0
6144, 10000   |              501.9 |              433.7 |               435.7
1024, 20000   |              486.9 |              471.7 |               471.3
512, 1536     |               21.9 |               21.7 |                21.6
512, 6144     |               64.8 |               62.7 |                62.9
512, 10240    |              139.6 |              100.2 |               100.2
1000, 1000    |               19.3 |               19.8 |                20.5
2000, 2000    |               36.5 |               34.3 |                34.3
10240, 10240  |              849.1 |              765.9 |               767.3
384, 128      |               20.0 |               20.0 |                20.1
2048, 1024    |               22.4 |               21.5 |                21.6
267, 513      |               20.6 |               20.2 |                19.7
67, 123479    |             2354.0 |              998.4 |               996.2
1024, 123479  |             3134.1 |             3418.3 |              3427.5
1234154, 512  |             5083.0 |             4251.9 |              4278.5
2048, 66679   |             2052.4 |             2323.9 |              2325.0
200, 256      |               19.7 |               19.7 |                19.9
1000, 256     |               20.0 |               19.5 |                19.6
6000, 256     |               20.2 |               19.4 |                19.7
6272, 256     |               19.9 |               19.7 |                20.2
200, 512      |               19.7 |               19.9 |                19.5
1000, 512     |               20.1 |               19.8 |                19.6
6000, 512     |               23.1 |               20.9 |                21.3
6272, 512     |               24.0 |               22.4 |                22.9
200, 1024     |               19.4 |               19.8 |                19.7
1000, 1024    |               19.7 |               19.5 |                20.1
6000, 1024    |               51.9 |               34.1 |                34.6
6272, 1024    |               55.5 |               39.0 |                37.9
200, 2048     |               24.5 |               24.4 |                24.4
1000, 2048    |               30.3 |               29.1 |                28.9
6000, 2048    |              109.2 |               96.5 |                96.9
6272, 2048    |              111.4 |              100.1 |               100.3
200, 3072     |               33.0 |               33.5 |                33.4
1000, 3072    |               41.4 |               39.5 |                39.5
6000, 3072    |              156.3 |              146.5 |               146.9
6272, 3072    |              160.9 |              150.2 |               151.0
3000000, 512  |            12385.8 |            10336.7 |             10366.0

Times are in microseconds (us).
```

Reviewed By: q10

Differential Revision: D77698275
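The summary attributes the original backout to a launch grid that was too small for the chosen config. The kernel itself is not shown on this page, so the following is only a generic sketch of how that class of bug can arise, not the actual fix: if the grid size is computed with floor division, the tail of a non-divisible input is never processed. The names `cdiv`, `elements_covered`, and `block_size` are hypothetical and chosen for illustration.

```python
def cdiv(n: int, d: int) -> int:
    """Ceiling division: smallest integer k such that k * d >= n."""
    return (n + d - 1) // d


def elements_covered(numel: int, block_size: int, grid: int) -> int:
    """Elements touched when each of `grid` blocks handles one contiguous
    chunk of `block_size` elements (assumes no grid-stride loop)."""
    return min(numel, grid * block_size)


numel = 267 * 513      # a non-power-of-two shape from the test list above
block_size = 1024      # hypothetical config value, not taken from the PR

too_small = numel // block_size        # floor division drops the tail
big_enough = cdiv(numel, block_size)   # ceiling division covers every element

print("floor grid:", too_small, "covers", elements_covered(numel, block_size, too_small))
print("ceil grid: ", big_enough, "covers", elements_covered(numel, block_size, big_enough))
```

Odd shapes such as (267, 513) and (67, 123479) in the correctness list are the kind of input that would expose an undersized grid, since their element counts are not multiples of common block sizes.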
@flaviotruzzi has exported this pull request. If you are a Meta employee, you can view the originating diff in D77698275.
2fb84a9 to f64a51e