Skip to content

Commit 6c1e1f3

Browse files
authored
[MRG] Update tests and documentation (#484)
* remove old macos and windows tets update requirements * speedup ssw and continuaous ot exmaples * speedup regpath and variane * speedup conv 2d example + continuous stick * speedup regpath
1 parent 5faa4fb commit 6c1e1f3

File tree

11 files changed

+44
-41
lines changed

11 files changed

+44
-41
lines changed

.github/workflows/build_tests.yml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ jobs:
8383
pip install -e .
8484
- name: Run tests
8585
run: |
86-
python -m pytest --durations=20 -v test/ ot/ --ignore ot/gpu/ --color=yes
86+
python -m pytest --durations=20 -v test/ ot/ --color=yes
8787
8888
8989
macos:
@@ -92,7 +92,7 @@ jobs:
9292
strategy:
9393
max-parallel: 4
9494
matrix:
95-
python-version: ["3.7", "3.8", "3.9", "3.10"]
95+
python-version: ["3.10"]
9696

9797
steps:
9898
- uses: actions/checkout@v1
@@ -107,10 +107,10 @@ jobs:
107107
run: |
108108
python -m pip install --upgrade pip
109109
pip install -r requirements.txt
110-
pip install pytest "pytest-cov<2.6"
110+
pip install pytest
111111
- name: Run tests
112112
run: |
113-
python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
113+
python -m pytest --durations=20 -v test/ ot/ --color=yes
114114
115115
116116
windows:
@@ -119,7 +119,7 @@ jobs:
119119
strategy:
120120
max-parallel: 4
121121
matrix:
122-
python-version: ["3.7", "3.8", "3.9", "3.10"]
122+
python-version: ["3.10"]
123123

124124
steps:
125125
- uses: actions/checkout@v1
@@ -151,8 +151,8 @@ jobs:
151151
- name: Install dependencies
152152
run: |
153153
python -m pip install -r .github/requirements_test_windows.txt
154-
python -m pip install torch==1.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
155-
python -m pip install pytest "pytest-cov<2.6"
154+
python -m pip3 install torch torchvision torchaudio
155+
python -m pip install pytest
156156
- name: Run tests
157157
run: |
158-
python -m pytest --durations=20 -v test/ ot/ --doctest-modules --ignore ot/gpu/ --cov=ot --color=yes
158+
python -m pytest --durations=20 -v test/ ot/ --color=yes

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459)
88
- Add tests on GPU for master branch and approved PR (PR #473)
99
- Add `median` method to all inherited classes of `backend.Backend` (PR #472)
10+
- Update tests for macOS and Windows, speedup documentation (PR #484)
1011

1112
#### Closed issues
1213

docs/Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ help:
5050
.PHONY: clean
5151
clean:
5252
rm -rf $(BUILDDIR)/*
53+
rm -rf source/gen_modules/*
54+
rm -rf source/auto_examples/*
5355

5456
.PHONY: html
5557
html:

examples/backends/plot_sliced_wass_grad_flow_pytorch.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@
4040
import ot
4141
import matplotlib.animation as animation
4242

43-
I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::4, ::4, 2]
44-
I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::4, ::4, 2]
43+
I1 = pl.imread('../../data/redcross.png').astype(np.float64)[::5, ::5, 2]
44+
I2 = pl.imread('../../data/tooth.png').astype(np.float64)[::5, ::5, 2]
4545

4646
sz = I2.shape[0]
4747
XX, YY = np.meshgrid(np.arange(sz), np.arange(sz))
@@ -67,7 +67,7 @@
6767

6868

6969
lr = 1e3
70-
nb_iter_max = 100
70+
nb_iter_max = 50
7171

7272
x_all = np.zeros((nb_iter_max, x1.shape[0], 2))
7373

@@ -129,7 +129,7 @@ def _update_plot(i):
129129
xbary_torch = torch.tensor(xbinit).to(device=device).requires_grad_(True)
130130

131131
lr = 1e3
132-
nb_iter_max = 100
132+
nb_iter_max = 50
133133

134134
x_all = np.zeros((nb_iter_max, xbary_torch.shape[0], 2))
135135

examples/backends/plot_ssw_unif_torch.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535

3636
torch.manual_seed(1)
3737

38-
N = 1000
38+
N = 500
3939
x0 = torch.rand(N, 3)
4040
x0 = F.normalize(x0, dim=-1)
4141

@@ -72,8 +72,8 @@ def plot_sphere(ax):
7272
x = x0.clone()
7373
x.requires_grad_(True)
7474

75-
n_iter = 500
76-
lr = 100
75+
n_iter = 100
76+
lr = 150
7777

7878
losses = []
7979
xvisu = torch.zeros(n_iter, N, 3)
@@ -82,7 +82,7 @@ def plot_sphere(ax):
8282
sw = ot.sliced_wasserstein_sphere_unif(x, n_projections=500)
8383
grad_x = torch.autograd.grad(sw, x)[0]
8484

85-
x = x - lr * grad_x
85+
x = x - lr * grad_x / np.sqrt(i / 10 + 1)
8686
x = F.normalize(x, p=2, dim=1)
8787

8888
losses.append(sw.item())
@@ -102,7 +102,7 @@ def plot_sphere(ax):
102102
# Plot trajectories of generated samples along iterations
103103
# -------------------------------------------------------
104104

105-
ivisu = [0, 25, 50, 75, 100, 150, 200, 350, 499]
105+
ivisu = [0, 10, 20, 30, 40, 50, 60, 70, 80]
106106

107107
fig = pl.figure(3, (10, 10))
108108
for i in range(9):
@@ -149,5 +149,5 @@ def _update_plot(i):
149149
ax.set_title('Iter. {}'.format(ivisu[i]))
150150

151151

152-
ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=100, repeat_delay=2000)
152+
ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter // 5, interval=200, repeat_delay=2000)
153153
# %%

examples/backends/plot_stoch_continuous_ot_pytorch.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,8 @@
2727
torch.manual_seed(42)
2828
np.random.seed(42)
2929

30-
n_source_samples = 10000
31-
n_target_samples = 10000
30+
n_source_samples = 1000
31+
n_target_samples = 1000
3232
theta = 2 * np.pi / 20
3333
noise_level = 0.1
3434

@@ -89,7 +89,7 @@ def forward(self, x):
8989
optimizer = torch.optim.Adam(list(u.parameters()) + list(v.parameters()), lr=.005)
9090

9191
# number of iteration
92-
n_iter = 1000
92+
n_iter = 500
9393
n_batch = 500
9494

9595

examples/barycenters/plot_convolutional_barycenter.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@
2929
this_file = os.path.realpath('__file__')
3030
data_path = os.path.join(Path(this_file).parent.parent.parent, 'data')
3131

32-
f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[:, :, 2]
33-
f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[:, :, 2]
34-
f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[:, :, 2]
35-
f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[:, :, 2]
32+
f1 = 1 - plt.imread(os.path.join(data_path, 'redcross.png'))[::2, ::2, 2]
33+
f2 = 1 - plt.imread(os.path.join(data_path, 'tooth.png'))[::2, ::2, 2]
34+
f3 = 1 - plt.imread(os.path.join(data_path, 'heart.png'))[::2, ::2, 2]
35+
f4 = 1 - plt.imread(os.path.join(data_path, 'duck.png'))[::2, ::2, 2]
3636

3737
f1 = f1 / np.sum(f1)
3838
f2 = f2 / np.sum(f2)

examples/plot_OT_1D_smooth.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
#
1515
# License: MIT License
1616

17-
# sphinx_gallery_thumbnail_number = 6
17+
# sphinx_gallery_thumbnail_number = 5
1818

1919
import numpy as np
2020
import matplotlib.pylab as pl

examples/sliced-wasserstein/plot_variance.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
# %% parameters and data generation
3131

32-
n = 500 # nb samples
32+
n = 200 # nb samples
3333

3434
mu_s = np.array([0, 0])
3535
cov_s = np.array([[1, 0], [0, 1]])
@@ -58,9 +58,9 @@
5858
# Sliced Wasserstein distance for different seeds and number of projections
5959
# -------------------------------------------------------------------------
6060

61-
n_seed = 50
62-
n_projections_arr = np.logspace(0, 3, 25, dtype=int)
63-
res = np.empty((n_seed, 25))
61+
n_seed = 20
62+
n_projections_arr = np.logspace(0, 3, 10, dtype=int)
63+
res = np.empty((n_seed, 10))
6464

6565
# %% Compute statistics
6666
for seed in range(n_seed):

examples/sliced-wasserstein/plot_variance_ssw.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
# %% parameters and data generation
3030

31-
n = 500 # nb samples
31+
n = 200 # nb samples
3232

3333
xs = np.random.randn(n, 3)
3434
xt = np.random.randn(n, 3)
@@ -81,9 +81,9 @@
8181
# Spherical Sliced Wasserstein for different seeds and number of projections
8282
# --------------------------------------------------------------------------
8383

84-
n_seed = 50
85-
n_projections_arr = np.logspace(0, 3, 25, dtype=int)
86-
res = np.empty((n_seed, 25))
84+
n_seed = 20
85+
n_projections_arr = np.logspace(0, 3, 10, dtype=int)
86+
res = np.empty((n_seed, 10))
8787

8888
# %% Compute statistics
8989
for seed in range(n_seed):

0 commit comments

Comments
 (0)