From 6228586bdad9db286cd9bd4f951f4ef03c978a8c Mon Sep 17 00:00:00 2001
From: Jasper Everink
Date: Thu, 6 Mar 2025 13:10:59 +0100
Subject: [PATCH 1/3] Add l2-constrained TV minimization algorithm

---
 ts_algorithms/tv_min.py | 85 +++++++++++++++++++++++++++++++++++++++++
 1 file changed, 85 insertions(+)

diff --git a/ts_algorithms/tv_min.py b/ts_algorithms/tv_min.py
index 4bb49b7..fd11404 100644
--- a/ts_algorithms/tv_min.py
+++ b/ts_algorithms/tv_min.py
@@ -135,3 +135,88 @@ def tv_min2d(A, y, lam, num_iterations=500, L=None, non_negativity=False, progre
             break
 
     return u
+
+
+def l2con_tv_min2d(A, y, eps, num_iterations=500, tradeoff = 1.0, L=None, progress_bar=False, callbacks=()):
+    """Computes the l2-constrained total-variation minimization using Chambolle-Pock
+
+    Assumes that the data is a single 2D slice. A 3D version with 3D
+    gradients is work in progress.
+
+    :param A: `tomosipo.Operator`
+        Projection operator
+    :param y: `torch.Tensor`
+        Projection data
+    :param eps: `float`
+        Upper bound on the norm of the residual A(u) - y.
+    :param num_iterations: `int`
+        Number of iterations
+    :param tradeoff: `float`
+        Trade-off between the primal and dual steps of Chambolle-Pock.
+        Default: 1.0
+    :param L:
+        operator norm of operator A
+    :param progress_bar: `bool`
+        Whether to show a progress bar on the command line interface.
+        Default: False
+    :param callbacks:
+        Iterable containing functions or callable objects. Each callback will
+        be called every iteration with the current estimate and iteration
+        number as arguments. If any callback returns True, the algorithm stops
+        after this iteration. This can be used for logging, tracking or
+        alternative stopping conditions.
+    :returns:
+    :rtype:
+
+    """
+
+    dev = y.device
+
+    # Normalize the forward operator and data to improve conditioning.
+    # See comment in 'tv_min2d' above.
+    scale = operator_norm(A)
+    S = ts.scale(1 / scale, pos=A.domain.pos)
+    A = ts.operator(S * A.domain, S * A.range.to_vec())
+    y = y / scale
+    eps = eps / scale
+
+    if L is None:
+        L = operator_norm_plus_grad(A, num_iter=100)
+
+    t = tradeoff / L
+    s = 1.0 / (L*tradeoff)
+    theta = 1
+
+    u = torch.zeros(A.domain_shape, device=dev)
+    p = torch.zeros(A.range_shape, device=dev)
+    q = grad_2D(u) # contains zeros (and has correct shape)
+    u_avg = torch.clone(u)
+
+    for iteration in tqdm.trange(num_iterations, disable=not progress_bar):
+
+        """
+        Algorithm 7 in the paper cited at the top of this file suggests:
+        p = max(torch.norm(p + s * (A(u_avg)- y)) - s * eps, 0.0)*(p + s * (A(u_avg)- y)).
+        However, this formula/proximal operator is incorrect.
+        """
+
+        p = (1 - s * eps/(max(torch.norm(p + s * (A(u_avg)- y)), s*eps)))*(p + s * (A(u_avg)- y))
+        q = clip(q + s * grad_2D(u_avg), 1)
+
+
+        # TODO Print duality gap and feasibility as part of the tqdm progress bar
+        #primal = torch.norm(grad_2D(u))
+        #dual = -eps*torch.norm(p) - torch.inner(p.flatten(), y.flatten())
+        #print(primal - dual)
+
+        u_new = u - (t * A.T(p) + t * grad_2D_T(q))
+        u_avg = u_new + theta * (u_new - u)
+        u = u_new
+
+        # Call all callbacks and stop iterating if one of the callbacks
+        # indicates to stop
+        if call_all_callbacks(callbacks, u, iteration):
+            break
+
+    return u
+        
\ No newline at end of file
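
Note on the dual update in this patch: the corrected step applied to p can be read as the proximal operator of the convex conjugate of the indicator of the l2 ball {z : ||z - y|| <= eps}, which via the Moreau identity reduces to rescaling p + s * (A(u_avg) - y) by the factor 1 - s*eps / max(||.||, s*eps). A minimal standalone sketch of that step in plain PyTorch follows; the names prox_l2ball_conjugate and p_tilde are illustrative and not part of the patch:

    import torch

    def prox_l2ball_conjugate(p, Au_avg, y, s, eps):
        # Chambolle-Pock dual step for the constraint ||A(u) - y|| <= eps:
        # proximal operator of the conjugate of the l2-ball indicator,
        # obtained through the Moreau identity.
        p_tilde = p + s * (Au_avg - y)
        scale = 1 - s * eps / torch.clamp(torch.norm(p_tilde), min=s * eps)
        return scale * p_tilde  # equals max(1 - s*eps/||p_tilde||, 0) * p_tilde

    # Sanity check: with eps = 0 the step reduces to a plain dual ascent step.
    r = torch.tensor([1.0, -2.0, 0.5])   # stands in for A(u_avg) - y
    p0 = torch.zeros(3)
    assert torch.allclose(prox_l2ball_conjugate(p0, r, torch.zeros(3), 0.1, 0.0), 0.1 * r)
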
+ """ + + p = (1 - s * eps/(max(torch.norm(p + s * (A(u_avg)- y)), s*eps)))*(p + s * (A(u_avg)- y)) + q = clip(q + s * grad_2D(u_avg), 1) + + + # TODO Print duality gap and feasibility as part of the tqdm progress bar + #primal = torch.norm(grad_2D(u)) + #dual = -eps*torch.norm(p) - torch.inner(p.flatten(), y.flatten()) + #print(primal - dual) + + u_new = u - (t * A.T(p) + t * grad_2D_T(q)) + u_avg = u_new + theta * (u_new - u) + u = u_new + + # Call all callbacks and stop iterating if one of the callbacks + # indicates to stop + if call_all_callbacks(callbacks, u, iteration): + break + + return u + \ No newline at end of file From 0308a403f2fff96522bb885d35495ecb84e1ced1 Mon Sep 17 00:00:00 2001 From: Jasper Everink Date: Thu, 6 Mar 2025 13:19:54 +0100 Subject: [PATCH 2/3] Add convergence statistics to l2con_tv_min2d --- ts_algorithms/tv_min.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/ts_algorithms/tv_min.py b/ts_algorithms/tv_min.py index fd11404..5c78d0f 100644 --- a/ts_algorithms/tv_min.py +++ b/ts_algorithms/tv_min.py @@ -192,7 +192,9 @@ def l2con_tv_min2d(A, y, eps, num_iterations=500, tradeoff = 1.0, L=None, progre q = grad_2D(u) # contains zeros (and has correct shape) u_avg = torch.clone(u) - for iteration in tqdm.trange(num_iterations, disable=not progress_bar): + + pbar = tqdm.trange(num_iterations, disable=not progress_bar) + for iteration in pbar: """ Algorithm 7 in the paper cited at the top of this file suggest: @@ -203,20 +205,21 @@ def l2con_tv_min2d(A, y, eps, num_iterations=500, tradeoff = 1.0, L=None, progre p = (1 - s * eps/(max(torch.norm(p + s * (A(u_avg)- y)), s*eps)))*(p + s * (A(u_avg)- y)) q = clip(q + s * grad_2D(u_avg), 1) - - # TODO Print duality gap and feasibility as part of the tqdm progress bar - #primal = torch.norm(grad_2D(u)) - #dual = -eps*torch.norm(p) - torch.inner(p.flatten(), y.flatten()) - #print(primal - dual) u_new = u - (t * A.T(p) + t * grad_2D_T(q)) u_avg = u_new + theta * (u_new - u) u = u_new + if progress_bar: # Only compute convergence statistics when they will actually be shown + primal_obj = torch.norm(grad_2D(u)) + dual_obj = -eps*torch.norm(p) - torch.inner(p.flatten(), y.flatten()) + residual = torch.norm(A(u) - y) + pbar.set_postfix(duality_gap = primal_obj - dual_obj, + residual = residual) + # Call all callbacks and stop iterating if one of the callbacks # indicates to stop if call_all_callbacks(callbacks, u, iteration): break return u - \ No newline at end of file From f8c33be05770e54920c73a6ab5adee710f2f1eb0 Mon Sep 17 00:00:00 2001 From: Jasper Everink Date: Fri, 7 Mar 2025 08:56:05 +0100 Subject: [PATCH 3/3] Print progress fix --- ts_algorithms/__init__.py | 2 +- ts_algorithms/tv_min.py | 15 +++++++++++---- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/ts_algorithms/__init__.py b/ts_algorithms/__init__.py index 3c4b0b9..2a8ade3 100644 --- a/ts_algorithms/__init__.py +++ b/ts_algorithms/__init__.py @@ -21,7 +21,7 @@ def __get_version(): from .sirt import sirt from .fbp import fbp -from .tv_min import tv_min2d +from .tv_min import tv_min2d, l2con_tv_min2d from .operators import operator_norm, ATA_max_eigenvalue from .fdk import fdk from .nag_ls import nag_ls diff --git a/ts_algorithms/tv_min.py b/ts_algorithms/tv_min.py index 5c78d0f..e59f3e0 100644 --- a/ts_algorithms/tv_min.py +++ b/ts_algorithms/tv_min.py @@ -119,8 +119,9 @@ def tv_min2d(A, y, lam, num_iterations=500, L=None, non_negativity=False, progre p = torch.zeros(A.range_shape, device=dev) q = grad_2D(u) # 
From f8c33be05770e54920c73a6ab5adee710f2f1eb0 Mon Sep 17 00:00:00 2001
From: Jasper Everink
Date: Fri, 7 Mar 2025 08:56:05 +0100
Subject: [PATCH 3/3] Print progress fix

---
 ts_algorithms/__init__.py |  2 +-
 ts_algorithms/tv_min.py   | 15 +++++++++++----
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/ts_algorithms/__init__.py b/ts_algorithms/__init__.py
index 3c4b0b9..2a8ade3 100644
--- a/ts_algorithms/__init__.py
+++ b/ts_algorithms/__init__.py
@@ -21,7 +21,7 @@ def __get_version():
 
 from .sirt import sirt
 from .fbp import fbp
-from .tv_min import tv_min2d
+from .tv_min import tv_min2d, l2con_tv_min2d
 from .operators import operator_norm, ATA_max_eigenvalue
 from .fdk import fdk
 from .nag_ls import nag_ls
diff --git a/ts_algorithms/tv_min.py b/ts_algorithms/tv_min.py
index 5c78d0f..e59f3e0 100644
--- a/ts_algorithms/tv_min.py
+++ b/ts_algorithms/tv_min.py
@@ -119,8 +119,9 @@ def tv_min2d(A, y, lam, num_iterations=500, L=None, non_negativity=False, progre
     p = torch.zeros(A.range_shape, device=dev)
     q = grad_2D(u) # contains zeros (and has correct shape)
     u_avg = torch.clone(u)
-
-    for iteration in tqdm.trange(num_iterations, disable=not progress_bar):
+
+    pbar = tqdm.trange(num_iterations, disable=not progress_bar)
+    for iteration in pbar:
         p = (p + s * (A(u_avg) - y)) / (1 + s)
         q = clip(q + s * grad_2D(u_avg), lam)
         u_new = u - (t * A.T(p) + t * grad_2D_T(q))
@@ -129,6 +130,12 @@ def tv_min2d(A, y, lam, num_iterations=500, L=None, non_negativity=False, progre
         u_avg = u_new + theta * (u_new - u)
         u = u_new
 
+        if progress_bar: # Only compute convergence statistics when they will actually be shown
+            primal_obj = torch.norm(A(u)-y)**2 + lam*torch.norm(grad_2D(u_avg), p=1)
+            dual_obj = -0.5*torch.norm(p)**2 - torch.inner(p.flatten(), y.flatten())
+            pbar.set_postfix(duality_gap = (primal_obj - dual_obj).item())
+
+
         # Call all callbacks and stop iterating if one of the callbacks
         # indicates to stop
         if call_all_callbacks(callbacks, u, iteration):
@@ -214,8 +221,8 @@ def l2con_tv_min2d(A, y, eps, num_iterations=500, tradeoff = 1.0, L=None, progre
             primal_obj = torch.norm(grad_2D(u))
             dual_obj = -eps*torch.norm(p) - torch.inner(p.flatten(), y.flatten())
             residual = torch.norm(A(u) - y)
-            pbar.set_postfix(duality_gap = primal_obj - dual_obj,
-                             residual = residual)
+            pbar.set_postfix(duality_gap = (primal_obj - dual_obj).item(),
+                             residual = residual.item())
 
         # Call all callbacks and stop iterating if one of the callbacks
         # indicates to stop
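
For reference, once the series is applied the solver exported in PATCH 3/3 could be exercised roughly as follows. This is only a sketch: the geometry, phantom, noise level, and the choice of eps below are illustrative placeholders, not part of the patches.

    import torch
    import tomosipo as ts
    from ts_algorithms import l2con_tv_min2d

    # Single-slice parallel-beam setup, matching the 2D assumption in the docstring.
    vg = ts.volume(shape=(1, 256, 256))
    pg = ts.parallel(angles=180, shape=(1, 384))
    A = ts.operator(vg, pg)

    phantom = torch.zeros(A.domain_shape)
    phantom[:, 96:160, 96:160] = 1.0          # simple square phantom

    y = A(phantom)
    sigma = 0.01 * y.max()
    y = y + sigma * torch.randn_like(y)       # mild Gaussian noise

    # Constrain the residual norm to roughly the expected noise level.
    eps = float(sigma) * y.numel() ** 0.5
    rec = l2con_tv_min2d(A, y, eps, num_iterations=500, progress_bar=True)
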