@@ -242,7 +242,7 @@
 <div class="pytorch-left-menu-search">

 <div class="version">
- <a href='https://pytorch.org/docs/versions.html'>main (2.3.0a0+git3ab0894 ) ▼</a>
+ <a href='https://pytorch.org/docs/versions.html'>main (2.3.0a0+git663dd5d ) ▼</a>
 </div>


@@ -1494,24 +1494,24 @@ <h1>Source code for torch</h1><div class="highlight"><pre>
 <span class="sd"> Supports three settings:</span>

 <span class="sd"> * "highest", float32 matrix multiplications use the float32 datatype (24 mantissa</span>
-<span class="sd"> bits) for internal computations.</span>
+<span class="sd"> bits with 23 bits explicitly stored) for internal computations.</span>
 <span class="sd"> * "high", float32 matrix multiplications either use the TensorFloat32 datatype (10</span>
-<span class="sd"> mantissa bits) or treat each float32 number as the sum of two bfloat16 numbers</span>
-<span class="sd"> (approximately 16 mantissa bits), if the appropriate fast matrix multiplication</span>
+<span class="sd"> mantissa bits explicitly stored) or treat each float32 number as the sum of two bfloat16 numbers</span>
+<span class="sd"> (approximately 16 mantissa bits with 14 bits explicitly stored), if the appropriate fast matrix multiplication</span>
 <span class="sd"> algorithms are available. Otherwise float32 matrix multiplications are computed</span>
 <span class="sd"> as if the precision is "highest". See below for more information on the bfloat16</span>
 <span class="sd"> approach.</span>
 <span class="sd"> * "medium", float32 matrix multiplications use the bfloat16 datatype (8 mantissa</span>
-<span class="sd"> bits) for internal computations, if a fast matrix multiplication algorithm</span>
+<span class="sd"> bits with 7 bits explicitly stored) for internal computations, if a fast matrix multiplication algorithm</span>
 <span class="sd"> using that datatype internally is available. Otherwise float32</span>
 <span class="sd"> matrix multiplications are computed as if the precision is "high".</span>

 <span class="sd"> When using "high" precision, float32 multiplications may use a bfloat16-based algorithm</span>
 <span class="sd"> that is more complicated than simply truncating to some smaller number of mantissa bits</span>
-<span class="sd"> (e.g. 10 for TensorFloat32, 8 for bfloat16). Refer to [Henry2019]_ for a complete</span>
+<span class="sd"> (e.g. 10 for TensorFloat32, 7 for bfloat16 explicitly stored). Refer to [Henry2019]_ for a complete</span>
 <span class="sd"> description of this algorithm. To briefly explain here, the first step is to realize</span>
 <span class="sd"> that we can perfectly encode a single float32 number as the sum of three bfloat16</span>
-<span class="sd"> numbers (because float32 has 24 mantissa bits while bfloat16 has 8, and both have the</span>
+<span class="sd"> numbers (because float32 has 23 mantissa bits while bfloat16 has 7 explicitly stored, and both have the</span>
 <span class="sd"> same number of exponent bits). This means that the product of two float32 numbers can</span>
 <span class="sd"> be exactly given by the sum of nine products of bfloat16 numbers. We can then trade</span>
 <span class="sd"> accuracy for speed by dropping some of these products. The "high" precision algorithm</span>
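
The docstring being edited above is the one for `torch.set_float32_matmul_precision`. As a quick orientation for readers of this diff, here is a minimal usage sketch of the three documented settings; it only uses the public `torch.set_float32_matmul_precision` / `torch.get_float32_matmul_precision` API named in the docstring.

```python
import torch

# The three documented settings are "highest" (default), "high", and "medium".
torch.set_float32_matmul_precision("high")    # allow TF32 / bfloat16-based matmul kernels
print(torch.get_float32_matmul_precision())   # -> "high"

a = torch.randn(1024, 1024)
b = torch.randn(1024, 1024)
c = a @ b  # inputs/outputs stay float32; internal precision follows the setting above

torch.set_float32_matmul_precision("highest") # restore the default
```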
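The three-way bfloat16 encoding described in the last paragraph of the docstring (and in [Henry2019]_) can also be illustrated directly. This is only a sketch of the idea, not the fused kernel PyTorch actually dispatches to; the helper name `split_into_three_bf16` is made up for illustration.

```python
import torch

def split_into_three_bf16(x: torch.Tensor):
    """Encode a float32 tensor as the sum of three bfloat16 tensors.

    Illustrative only: each bfloat16 piece carries about 8 significand bits
    (1 implicit + 7 stored), so three pieces cover float32's 24-bit
    significand, and the residuals below are exact for typical values.
    """
    b0 = x.to(torch.bfloat16)
    r1 = x - b0.to(torch.float32)          # residual after the first piece
    b1 = r1.to(torch.bfloat16)
    b2 = (r1 - b1.to(torch.float32)).to(torch.bfloat16)
    return b0, b1, b2

x = torch.randn(256, 256)
y = torch.randn(256, 256)
x0, x1, x2 = split_into_three_bf16(x)
y0, y1, y2 = split_into_three_bf16(y)

# x reconstructs exactly from its three pieces (exponent-range edge cases aside).
print(torch.equal(x0.float() + x1.float() + x2.float(), x))

# Expanding (x0 + x1 + x2) @ (y0 + y1 + y2) gives nine bfloat16-by-bfloat16
# products; summing all nine recovers x @ y up to float32 rounding, and
# dropping the least significant terms is the accuracy/speed trade-off the
# docstring describes.
terms = [xi.float() @ yj.float() for xi in (x0, x1, x2) for yj in (y0, y1, y2)]
print((sum(terms) - x @ y).abs().max())
```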