Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ts_algorithms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __get_version():
from .sirt import sirt
from .em import em
from .fbp import fbp
from .tv_min import tv_min2d
from .tv_min import tv_min2d, l2con_tv_min2d
from .operators import operator_norm, ATA_max_eigenvalue
from .fdk import fdk
from .nag_ls import nag_ls
Expand Down
99 changes: 97 additions & 2 deletions ts_algorithms/tv_min.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,9 @@ def tv_min2d(A, y, lam, num_iterations=500, L=None, non_negativity=False, progre
p = torch.zeros(A.range_shape, device=dev)
q = grad_2D(u) # contains zeros (and has correct shape)
u_avg = torch.clone(u)

for iteration in tqdm.trange(num_iterations, disable=not progress_bar):

pbar = tqdm.trange(num_iterations, disable=not progress_bar)
for iteration in pbar:
p = (p + s * (A(u_avg) - y)) / (1 + s)
q = clip(q + s * grad_2D(u_avg), lam)
u_new = u - (t * A.T(p) + t * grad_2D_T(q))
Expand All @@ -129,6 +130,100 @@ def tv_min2d(A, y, lam, num_iterations=500, L=None, non_negativity=False, progre
u_avg = u_new + theta * (u_new - u)
u = u_new

if progress_bar: # Only compute convergence statistics when they will actually be shown
primal_obj = torch.norm(A(u)-y)**2 + lam*torch.norm(grad_2D(u_avg), p=1)
dual_obj = -0.5*torch.norm(p)**2 - torch.inner(p.flatten(), y.flatten())
pbar.set_postfix(duality_gap = (primal_obj - dual_obj).item())


# Call all callbacks and stop iterating if one of the callbacks
# indicates to stop
if call_all_callbacks(callbacks, u, iteration):
break

return u


def l2con_tv_min2d(A, y, eps, num_iterations=500, tradeoff = 1.0, L=None, progress_bar=False, callbacks=()):
"""Computes the total-variation minimization using Chambolle-Pock

Assumes that the data is a single 2D slice. A 3D version with 3D
gradients is work in progress.

:param A: `tomosipo.Operator`
Projection operator
:param y: `torch.Tensor`
Projection data
:param eps: `float`
Upper bound on the residual.
:param num_iterations: `int`
Number of iterations
:param tradeoff: `float`
Trade-off between the primal and dual steps of Chambolle-Pock.
Default: 1.0
:param L:
operator norm of operator A
:param progress_bar: `bool`
Whether to show a progress bar on the command line interface.
Default: False
:param callbacks:
Iterable containing functions or callable objects. Each callback will
be called every iteration with the current estimate and iteration
number as arguments. If any callback returns True, the algorithm stops
after this iteration. This can be used for logging, tracking or
alternative stopping conditions.
:returns:
:rtype:

"""

dev = y.device

# Normalize the forward operator and data to improve conditioning.
# See comment in 'tv_min2d' above.
scale = operator_norm(A)
S = ts.scale(1 / scale, pos=A.domain.pos)
A = ts.operator(S * A.domain, S * A.range.to_vec())
y = y / scale
eps = eps / scale

if L is None:
L = operator_norm_plus_grad(A, num_iter=100)

t = tradeoff / L
s = 1.0 / (L*tradeoff)
theta = 1

u = torch.zeros(A.domain_shape, device=dev)
p = torch.zeros(A.range_shape, device=dev)
q = grad_2D(u) # contains zeros (and has correct shape)
u_avg = torch.clone(u)


pbar = tqdm.trange(num_iterations, disable=not progress_bar)
for iteration in pbar:

"""
Algorithm 7 in the paper cited at the top of this file suggest:
p = max(torch.norm(p + s * (A(u_avg)- y)) - s * eps, 0.0)*(p + s * (A(u_avg)- y)).
However, this formula/proximal operator is incorrect.
"""

p = (1 - s * eps/(max(torch.norm(p + s * (A(u_avg)- y)), s*eps)))*(p + s * (A(u_avg)- y))
q = clip(q + s * grad_2D(u_avg), 1)


u_new = u - (t * A.T(p) + t * grad_2D_T(q))
u_avg = u_new + theta * (u_new - u)
u = u_new

if progress_bar: # Only compute convergence statistics when they will actually be shown
primal_obj = torch.norm(grad_2D(u))
dual_obj = -eps*torch.norm(p) - torch.inner(p.flatten(), y.flatten())
residual = torch.norm(A(u) - y)
pbar.set_postfix(duality_gap = (primal_obj - dual_obj).item(),
residual = residual.item())

# Call all callbacks and stop iterating if one of the callbacks
# indicates to stop
if call_all_callbacks(callbacks, u, iteration):
Expand Down