This repository was archived by the owner on Nov 1, 2024. It is now read-only.

Commit 21a5c8e

Add unittests for AnyPrecisionOptimizer (#62)
* AnyPrecision unittests
1 parent d611c4c commit 21a5c8e

3 files changed (+86, -1 lines)
Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+from .anyprecision_optimizer import AnyPrecisionAdamW

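With AnyPrecisionAdamW re-exported here, the optimizer can be imported directly from torchdistx.optimizers. A minimal usage sketch (not part of this commit), assuming torchdistx is installed and that the constructor accepts the AdamW-style keyword arguments exercised by the new test (betas, weight_decay):

import torch
import torch.nn as nn

from torchdistx.optimizers import AnyPrecisionAdamW

# Sketch only: a tiny model and a single optimizer step with default dtypes.
model = nn.Linear(5, 5)
opt = AnyPrecisionAdamW(model.parameters(), betas=(0.9, 0.999), weight_decay=0.01)

loss = model(torch.randn(8, 5)).sum()
loss.backward()
opt.step()
opt.zero_grad()
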
src/python/torchdistx/optimizers/anyprecision_optimizer.py

Lines changed: 2 additions & 1 deletion

@@ -46,7 +46,8 @@ def __init__(
     momentum_dtype = dtype for momentum (default: BFloat32)
     variance_dtype = dtype for uncentered variance (default: BFloat16)
     compensation_buffer_dtype = dtype for Kahan summation
-        buffer (default: BFloat16)
+        buffer (default: BFloat16). Only used if
+        ``use_kahan_summation=True``.

     # Usage
     This optimizer implements optimizer states, and Kahan summation
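
The added docstring sentence ties compensation_buffer_dtype to Kahan summation. A hedged sketch of a call that exercises it, using only keyword names that appear in this docstring and in the new test file; defaults such as lr are left untouched since they are not part of this diff:

import torch
import torch.nn as nn

from torchdistx.optimizers import AnyPrecisionAdamW

model = nn.Linear(5, 5)

# Sketch only: keyword names come from the docstring and the test in this
# commit; per the docstring, compensation_buffer_dtype only takes effect
# when use_kahan_summation=True.
opt = AnyPrecisionAdamW(
    model.parameters(),
    use_kahan_summation=True,
    momentum_dtype=torch.bfloat16,
    variance_dtype=torch.bfloat16,
    compensation_buffer_dtype=torch.bfloat16,
)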
Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from copy import deepcopy

import torch
import torch.nn as nn
import torch.optim as optim
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchdistx.optimizers import AnyPrecisionAdamW


class TestAnyPrecisionOptimizer(TestCase):
    def _test_adam_equivalence(self, model, model_clone):
        # Test non-default options
        betas = (0.8, 0.88)
        weight_decay = 0.03

        adam_opt = optim.AdamW(
            model_clone.parameters(), betas=betas, weight_decay=weight_decay
        )
        anyprecision_adam = AnyPrecisionAdamW(
            model.parameters(),
            variance_dtype=torch.float32,
            betas=betas,
            weight_decay=weight_decay,
        )

        # Verify params are equal initially
        model_orig_params = [p.clone() for p in model.parameters()]
        for p1, p2 in zip(model_clone.parameters(), model_orig_params):
            self.assertEqual(p1, p2)

        for i in range(6):
            adam_opt.zero_grad()
            anyprecision_adam.zero_grad()
            inp = torch.randn(5, 5, device=next(model.parameters()).device)
            model(inp).sum().backward()
            model_clone(inp).sum().backward()
            adam_opt.step()
            anyprecision_adam.step()

            # Ensure params are modified from original
            if i == 0:
                for p1, p2 in zip(model.parameters(), model_orig_params):
                    self.assertNotEqual(p1, p2)

            for p1, p2 in zip(model.parameters(), model_clone.parameters()):
                self.assertEqual(p1, p2)

    @parametrize("device", ["cpu", "cuda"])
    def test_adam_equivalence(self, device):
        """
        Tests that AnyPrecisionAdamW is equivalent to AdamW when
        kahan summation and different dtypes for momentum, variance,
        and compensation buffer are turned off (i.e. all float32).
        """
        if device == "cuda" and not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available")

        model = nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5), nn.Linear(5, 5))
        if device == "cuda":
            model.cuda()

        model_clone = deepcopy(model)

        self._test_adam_equivalence(model, model_clone)


instantiate_parametrized_tests(TestAnyPrecisionOptimizer)

if __name__ == "__main__":
    run_tests()
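
Because the new module registers the class with instantiate_parametrized_tests and calls run_tests() under __main__, executing the file with Python should run both generated variants of test_adam_equivalence, with the CUDA case skipped when no GPU is available.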
