from proxtoolbox.experiments.experiment import Experiment
from proxtoolbox import proxoperators
from proxtoolbox.utils.cell import Cell, isCell
#for downloading data
import proxtoolbox.utils.GetData as GetData
import numpy as np
from scipy.io import loadmat
from numpy.linalg import norm
from math import sqrt
import matplotlib.pyplot as plt
from matplotlib.pyplot import subplots, show, figure
class ART_Experiment(Experiment):
    """
    ART experiment class.

    Linear feasibility experiment on the CT dataset ART_SheppLogan.mat:
    every row k of the system A x = b defines a hyperplane constraint
    {x : a_k . x = b_k}, and the iterates are driven by 'P_hyperplane'
    prox operators (directly, or inside a product-space formulation).
    """

    @staticmethod
    def getDefaultParameters():
        """
        Return the default parameter dictionary of the ART experiment.
        """
        defaultParams = {
            'experiment_name': 'ART',
            'object': 'complex',
            'constraint': 'convex',
            'rescale': True,
            'MAXIT': 20,
            'TOL': -1e-6,
            'lambda_0': 0.75,
            'lambda_max': 0.75,
            'lambda_switch': 13,
            'data_ball': 1e-15,
            'diagnostic': True,
            'iterate_monitor_name': 'CT_IterateMonitor',
            'rotate': False,
            'verbose': 0,
            'graphics': 1,
            'anim': 2,
            'debug': True
        }
        return defaultParams

    def __init__(self,
                 rescale=True,
                 **kwargs):
        """
        Create the ART experiment.

        Parameters
        ----------
        rescale : bool, optional
            If True (default), loadData() scales row k of A and entry k
            of b by 1 / (||a_k||^2 + 1e-20).
        **kwargs
            Forwarded unchanged to the Experiment base class.
        """
        # call parent's __init__ method
        super(ART_Experiment, self).__init__(**kwargs)
        self.rescale = rescale
        # data members filled in by loadData()
        self.inner_dimension = None
        self.outer_dimension = None
        self.block_step = None
        self.A = None       # dense system matrix
        self.b = None       # right-hand side, 1D
        self.aaT = None     # squared norm of each row of A (+1e-20)
        self.m = None       # number of rows of A
        self.n = None       # number of columns of A

    def loadData(self):
        """
        Load the ART dataset and create the initial iterate.

        Downloads the CT data if it cannot be found, reads
        ../InputData/CT/ART_SheppLogan.mat (path relative to the current
        working directory), builds the dense matrix ``A`` and the 1D
        right-hand side ``b``, optionally rescales each row of the
        system, caches the squared row norms in ``aaT`` and sets ``u0``
        to zeros (one zero vector per set in the 'product space'
        formulation).
        """
        # make sure input data can be found, otherwise download it
        GetData.getData('CT')

        # load data
        print('Loading data file ART_SheppLogan.mat ')
        data_Shepp = loadmat('../InputData/CT/ART_SheppLogan.mat')
        N = data_Shepp['N'].item()
        p = data_Shepp['p'].item()
        theta = data_Shepp['theta']
        b_ex = data_Shepp['b_ex']
        self.b = b_ex.reshape(b_ex.size)  # make it a 1D array instead of 2D
        # A is stored as a csc sparse scipy matrix; work with a dense copy
        self.A = data_Shepp['A'].toarray()

        self.Ny = N**2
        self.Nx = 1
        self.Nz = 1
        self.inner_dimension = p
        self.outer_dimension = theta.size
        self.block_step = p

        # The next is a generic scaling that removes the dependence of the
        # norms from the problem dimension. More specific normalizations
        # are up to the user.
        self.norm_data = np.sqrt(self.Ny)

        if self.rescale:
            # Scale row k of A and entry k of b by 1 / (||a_k||^2 + 1e-20).
            # The squared row norms are diag(A A^T), computed here without
            # forming the (m x m) Gram matrix; the row scaling is applied
            # by broadcasting instead of a product with a dense (m x m)
            # diagonal matrix. Both avoid O(m^2) memory.
            tmp = 1 / (np.einsum('ij,ij->i', self.A, self.A) + 1e-20)
            self.A = tmp[:, np.newaxis] * self.A
            self.b = self.b * tmp

        # Squared norm a_k . a_k of every row of the (possibly rescaled)
        # matrix; the 1e-20 keeps every entry strictly positive.
        self.aaT = np.einsum('ij,ij->i', self.A, self.A) + 1e-20

        self.sets = self.block_step * self.outer_dimension
        if self.formulation == 'product space':
            self.product_space_dimension = self.sets
            self.u0 = Cell(self.product_space_dimension)
            for j in range(self.product_space_dimension):
                self.u0[j] = np.zeros(self.Ny)
        else:
            self.u0 = np.zeros(self.Ny)
            self.product_space_dimension = 1
        self.m, self.n = self.A.shape

    def setupProxOperators(self):
        """
        Determine the prox operators to be used for this experiment.

        'product space' formulation: proxOperators is
        ['P_diag', 'Prox_product_space'], and productProxOperators holds
        one 'P_hyperplane' per product-space component.
        Otherwise: proxOperators holds one 'P_hyperplane' per set.
        """
        super(ART_Experiment, self).setupProxOperators()  # call parent's method
        self.proxOperators = []
        self.productProxOperators = []
        if self.formulation == 'product space':
            self.nProx = 2
            self.proxOperators.append('P_diag')
            self.proxOperators.append('Prox_product_space')
            self.n_product_Prox = self.product_space_dimension
            # TODO maybe change to P_parallel_hyperplane
            self.productProxOperators.extend(
                ['P_hyperplane'] * self.n_product_Prox)
        else:
            self.nProx = self.sets
            self.product_space_dimension = 1
            self.proxOperators.extend(['P_hyperplane'] * self.nProx)

    def setFattening(self):
        """
        Optional method for fattening/regularizing sets.

        Called by the initialization() method after instantiating the
        prox operators but before instantiating the algorithm. Scales
        ``data_ball`` by the initial gap ``gap_0`` and sets ``TOL2``.
        """
        # Estimate the gap in the relevant metric
        computeGap = False  # deactivated for now as it takes too long to compute
        if computeGap:
            # simple for now...
            if self.formulation == 'product space':
                proxOps = self.productProxOperators
                u0 = self.u0[0]
            else:
                proxOps = self.proxOperators
                u0 = self.u0
            prox = proxOps[0](self)
            u1 = prox.eval(u0, 0)
            u2 = u1
            tmp_gap = 0
            for j in range(1, len(proxOps)):
                prox = proxOps[j](self)
                # NOTE(review): the cached value in the else-branch was
                # computed with u2 = prox.eval(u2, j) (sequential
                # projections), while this line evaluates every operator
                # at u1 (which gives 0.7668342258350885). Confirm which
                # variant is intended before re-enabling computeGap.
                u2 = prox.eval(u1, j)
                tmp_gap += (norm(u1 - u2)/self.norm_data)**2
        else:
            # Use cached value, obtained with u2 = prox.eval(u2, j).
            # (With u2 = prox.eval(u1, j) the value is 0.7668342258350885.)
            tmp_gap = 131.26250306395025
        gap_0 = sqrt(tmp_gap)
        # Sets the set fattening to be a percentage of the initial gap to
        # the unfattened set with respect to the relevant metric (KL or L2),
        # that percentage being given by the user input data_ball.
        self.data_ball = self.data_ball*gap_0
        # the second tolerance relative to the order of
        # magnitude of the metric
        self.TOL2 = self.data_ball*1e-15

    def show(self):
        """
        Display the results of the run.

        Top row: the first and the last monitored iterates reshaped to
        N x N images (column-major order). Bottom row: the iterate changes
        per iteration (log scale) and, when diagnostics are enabled and
        available, the gap distances (log scale).
        """
        u_m = self.output['u_monitor']
        if isCell(u_m):
            u = u_m[0]
            u2 = u_m[len(u_m) - 1]
        else:
            u = self.output['u']
            u2 = u_m
        if isCell(u):
            u = u[0]
        elif u.ndim == 2:
            u = u[:, 0]
        N = int(np.sqrt(u.shape[0]))
        u = np.reshape(u, (N, N), order='F')
        u2 = np.reshape(u2, (N, N), order='F')

        f, ((ax1, ax2), (ax3, ax4)) = subplots(2, 2,
                                               figsize=(self.figure_width, self.figure_height),
                                               dpi=self.figure_dpi)
        self.createImageSubFigure(f, ax1, u, 'best approximation- physical domain')
        self.createImageSubFigure(f, ax2, u2, 'best approximation - data constraint')

        changes = self.output['stats']['changes']
        elapsed = self.output['stats']['time']
        time_str = "{:.{}f}".format(elapsed, 5)  # 5 is precision
        xLabel = "Iterations (time = " + time_str + " s)"
        algo_desc = self.algorithm.getDescription()
        title = "Algorithm " + algo_desc

        ax3.plot(changes)
        ax3.set_yscale('log')
        ax3.set_xlabel(xLabel)
        ax3.set_ylabel('log of iterate difference')
        ax3.set_title(title)

        if self.diagnostic and 'gaps' in self.output['stats']:
            gaps = self.output['stats']['gaps']
            ax4.semilogy(gaps)
            ax4.set_xlabel(xLabel)
            ax4.set_ylabel('Log of the gap distance')
        else:
            ax4.remove()
        show()