* Add FP6 benchmark option to use BF16
* Change dequant bit-shifting logic for BF16
* Modify dequant + tensor core ops for bf16
* Template progress
* Modify fpx quant logic to include bf16
* Add tests for FP6 BF16
* Use type punning for large exponent multiplication
* Fix some TODOs
* Remove option to add exponent bias directly to the exponent bits
This approach turned out to be (much) slower than multiplying by 2^bias after the fact, which is why it was removed (a sketch of both strategies follows this list)
* Reformat
* Cleanup
* Fix alignment
* Remove templated input type whenever possible
* Remove templated input type whenever possible 2
* Remove templated input type whenever possible 3
* Less hacky way to construct a float with a large exponent
* rtol=1e-2 instead of 1e-3 for bfloat16 test
* Guards for SM75
* Remove redundant `__CUDA_ARCH__` guards in host code
`__CUDA_ARCH__` is only defined while compiling device code, so any check for it in `fp6_linear.cu`, which contains only host functions, always fails (a sketch follows this list)
* Fix consistency in checking `__CUDA_ARCH__` versions
* Update docs
* Make float bias a constexpr
* Update docs more
* Fix SM75 support
* Compile guard for sm<75
* Check for CUDA synchronous errors after kernel launch
If this is not done, a failed kernel launch goes unnoticed and the op fails silently, leading to unexpected behavior (a sketch follows this list)
* Updated compile guard
* Fix problematic usage of `__CUDA_ARCH__`
There are currently several usages of `__CUDA_ARCH__` that lead to undefined behavior; https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-arch documents the patterns that are not allowed (a sketch follows this list)
* Fix incorrect CUDA error handling
* Make the kernel fail for sm75 + bfloat16 inputs
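
A minimal sketch of the exponent-bias tradeoff mentioned above, assuming FP16 outputs; the helper names are illustrative and this is not the kernel's actual code:

```cuda
#include <cuda_fp16.h>
#include <cstdint>

// Rejected approach: patch the exponent bits of every dequantized FP16 value.
// This adds integer work to every element.
__device__ __forceinline__ __half scale_by_exponent_bits(uint16_t h_bits, int bias) {
    // FP16 layout: sign[15] | exponent[14:10] | mantissa[9:0]
    return __ushort_as_half(static_cast<uint16_t>(h_bits + (bias << 10)));
}

// Kept approach: build the constant 2^bias once by type-punning its IEEE-754
// bit pattern, then multiply it into the result after dequantization.
__device__ __forceinline__ float pow2_from_bits(int bias) {
    // float layout: sign[31] | exponent[30:23] (bias 127) | mantissa[22:0]
    return __uint_as_float(static_cast<uint32_t>(127 + bias) << 23);
}
```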
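
Why the host-side `__CUDA_ARCH__` guards were dead code, and the kind of runtime query host code can use instead (a sketch; the PR's actual dispatch logic may differ):

```cuda
#include <cuda_runtime.h>

// __CUDA_ARCH__ is defined only while nvcc compiles device code, so in a
// host-only function like the ones in fp6_linear.cu this guard never fires.
bool sm75_or_newer_dead_code() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
    return true;   // never part of the host compilation pass
#else
    return false;  // always taken: the guard is dead code
#endif
}

// Host code has to ask the runtime for the device's compute capability instead.
bool sm75_or_newer(int device) {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, device);
    return prop.major > 7 || (prop.major == 7 && prop.minor >= 5);
}
```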
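
The launch-check commit amounts to reading `cudaGetLastError()` immediately after the `<<<...>>>` launch (PyTorch extensions commonly wrap this in `C10_CUDA_KERNEL_LAUNCH_CHECK()`); a standalone sketch with a placeholder kernel:

```cuda
#include <cuda_runtime.h>
#include <cstdio>

__global__ void placeholder_kernel() {}

// A launch error (invalid configuration, no compatible SASS/PTX in the binary,
// ...) does not stop the host program; it only sets the per-thread error state.
// Reading it right after the launch turns a silent failure into a visible one.
bool launch_and_check(cudaStream_t stream) {
    placeholder_kernel<<<1, 1, 0, stream>>>();
    cudaError_t err = cudaGetLastError();  // synchronous (launch-time) errors
    if (err != cudaSuccess) {
        std::fprintf(stderr, "kernel launch failed: %s\n", cudaGetErrorString(err));
        return false;
    }
    return true;
}
```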
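
One of the `__CUDA_ARCH__` patterns the programming guide forbids, next to the legitimate device-side compile guard for sm<75 (kernel name and body are made up for illustration):

```cuda
// Undefined behavior per the programming guide: the signature of a __global__
// function (or of a __device__ variable) must not depend on whether
// __CUDA_ARCH__ is defined or on its value, because the host and device
// compilation passes would then disagree about it.
//
// #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
// typedef double compute_t;
// #else
// typedef float compute_t;
// #endif
// __global__ void bad_kernel(compute_t in, compute_t* out);

// Fine: keep the signature unconditional and guard only the body.
// __CUDA_ARCH__ is defined inside device code, so this is where the sm<75
// compile guard belongs.
__global__ void fp6_kernel(const float* in, float* out) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)
    // tensor-core (mma) path for sm_75 and newer
#else
    // older architectures get an empty body; the host side can refuse to
    // launch on them
#endif
}
```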
-This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 and W is in FP6 (E3M2 without infinities and NaN).
+This kernel is adapted from https://github.com/usyd-fsalab/fp6_llm. It performs linear op (A @ W.T), where A is in FP16 or BF16 and W is in FP6 (E3M2 without infinities and NaN).

On most hardware, this kernel is faster than FP16 linear for batch size from 1 to 128, and slower for batch size larger than or equal to 256. See https://github.com/usyd-fsalab/fp6_llm/issues/8 for a detailed discussion.

-See https://github.com/pytorch/ao/pull/223 for some benchmark results.
+See https://github.com/pytorch/ao/pull/223 and https://github.com/pytorch/ao/pull/1147 for some benchmark results.
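
For reference, a minimal scalar decoder for the E3M2 format named above (illustrative only; the actual kernel dequantizes packed FP6 into FP16/BF16 fragments with bit manipulation):

```cuda
#include <cmath>
#include <cstdint>

// Decode one 6-bit E3M2 value: 1 sign bit, 3 exponent bits (bias 3),
// 2 mantissa bits, no infinities or NaNs.
float fp6_e3m2_to_float(uint8_t v) {
    int sign = (v >> 5) & 0x1;
    int exp  = (v >> 2) & 0x7;
    int man  = v & 0x3;
    float mag = (exp == 0)
        ? std::ldexp(man / 4.0f, 1 - 3)           // subnormal: 2^(1-bias) * (m/4)
        : std::ldexp(1.0f + man / 4.0f, exp - 3); // normal: 2^(e-bias) * (1 + m/4)
    return sign ? -mag : mag;
}
```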