Source code for proxtoolbox.experiments.CT.ART_Experiment

from proxtoolbox.experiments.experiment import Experiment
from proxtoolbox import proxoperators
from proxtoolbox.utils.cell import Cell, isCell

#for downloading data
import proxtoolbox.utils.GetData as GetData

import numpy as np
from scipy.io import loadmat
from numpy.linalg import norm
from math import sqrt

import matplotlib.pyplot as plt
from matplotlib.pyplot import subplots, show, figure


class ART_Experiment(Experiment):
    """
    ART experiment class.

    Loads the Shepp-Logan CT dataset (matrix ``A`` and data vector ``b``)
    from ``ART_SheppLogan.mat`` and sets up ``P_hyperplane`` prox operators,
    either one per set or inside a product-space formulation.
    """

    @staticmethod
    def getDefaultParameters():
        """
        Return the default parameters of the ART experiment.

        Returns
        -------
        dict
            Mapping from parameter name to its default value.
        """
        defaultParams = {
            'experiment_name': 'ART',
            'object': 'complex',
            'constraint': 'convex',
            'rescale': True,
            'MAXIT': 20,
            'TOL': -1e-6,
            'lambda_0': 0.75,
            'lambda_max': 0.75,
            'lambda_switch': 13,
            'data_ball': 1e-15,
            'diagnostic': True,
            'iterate_monitor_name': 'CT_IterateMonitor',
            'rotate': False,
            'verbose': 0,
            'graphics': 1,
            'anim': 2,
            'debug': True,
        }
        return defaultParams

    def __init__(self, rescale=True, **kwargs):
        """
        Create the experiment.

        Parameters
        ----------
        rescale : bool, optional
            If True (default), ``loadData`` divides every row of ``A`` and
            the matching entry of ``b`` by the squared norm of that row
            (plus 1e-20).
        **kwargs
            Forwarded unchanged to the parent ``Experiment.__init__``.
        """
        # call parent's __init__ method
        super(ART_Experiment, self).__init__(**kwargs)
        self.rescale = rescale
        # data members filled in by loadData()
        self.inner_dimension = None
        self.outer_dimension = None
        self.block_step = None
        self.A = None
        self.b = None
        self.aaT = None
        self.m = None
        self.n = None

    def loadData(self):
        """
        Load the ART dataset and create the initial iterate.

        Downloads the 'CT' input data if it cannot be found, reads
        ``../InputData/CT/ART_SheppLogan.mat`` (relative to the current
        working directory) and sets ``A``, ``b``, ``aaT``, the problem
        dimensions, the number of sets and the initial iterate ``u0``
        (zeros; a Cell of zero vectors in the product-space formulation).
        """
        # make sure input data can be found, otherwise download it
        GetData.getData('CT')

        # load data
        print('Loading data file ART_SheppLogan.mat ')
        data_Shepp = loadmat('../InputData/CT/ART_SheppLogan.mat')
        N = data_Shepp['N'].item()
        p = data_Shepp['p'].item()
        theta = data_Shepp['theta']
        b_ex = data_Shepp['b_ex']
        self.b = b_ex.reshape(b_ex.size)  # make it a 1D array instead of 2D
        self.A = data_Shepp['A'].toarray()  # otherwise A is csc sparse scipy matrix

        self.Ny = N**2
        self.inner_dimension = p
        self.outer_dimension = theta.size
        self.Nx = 1
        self.Nz = 1
        self.block_step = p

        # the next is a generic scaling
        # that removes the dependence of the
        # norms from the problem dimension.
        # More specific normalizations are up to the user.
        self.norm_data = np.sqrt(self.Ny)

        if self.rescale:
            # diag(A @ A.T), i.e. sum_j A[k, j]**2 per row k, computed
            # without forming the m x m Gram matrix.
            row_sq = np.einsum('ij,ij->i', self.A, self.A)
            tmp = 1/(row_sq + 1e-20)
            # Same result as np.diag(tmp) @ A, but scales the rows by
            # broadcasting instead of building a dense m x m diagonal matrix.
            self.A = tmp[:, np.newaxis] * self.A
            self.b = self.b*tmp

        # Per-row a_k @ a_k of the (possibly rescaled) A; the 1e-20 keeps
        # the entries strictly positive (guards against zero rows).
        self.aaT = np.einsum('ij,ij->i', self.A, self.A) + 1e-20

        self.sets = self.block_step * self.outer_dimension
        if self.formulation == 'product space':
            self.product_space_dimension = self.sets
            self.u0 = Cell(self.product_space_dimension)
            for j in range(self.product_space_dimension):
                self.u0[j] = np.zeros(self.Ny)
        else:
            self.u0 = np.zeros(self.Ny)
            self.product_space_dimension = 1
        self.m = self.A.shape[0]
        self.n = self.A.shape[1]

    def setupProxOperators(self):
        """
        Determine the prox operators to be used for this experiment.

        In the 'product space' formulation the main operators are
        ``P_diag`` and ``Prox_product_space``, and ``productProxOperators``
        holds one ``P_hyperplane`` per product-space factor. Otherwise
        one ``P_hyperplane`` is used per set.
        """
        super(ART_Experiment, self).setupProxOperators()  # call parent's method
        self.proxOperators = []
        self.productProxOperators = []
        if self.formulation == 'product space':
            self.nProx = 2
            self.proxOperators.append('P_diag')
            self.proxOperators.append('Prox_product_space')
            self.n_product_Prox = self.product_space_dimension
            for _j in range(self.n_product_Prox):
                # TODO maybe change to P_parallel_hyperplane
                self.productProxOperators.append('P_hyperplane')
        else:
            self.nProx = self.sets
            self.product_space_dimension = 1
            for _j in range(self.nProx):
                self.proxOperators.append('P_hyperplane')

    def setFattening(self):
        """
        Optional method for fattening/regularizing sets.

        Called by the initialization() method after instantiating the prox
        operators but before instantiating the algorithm. Multiplies
        ``self.data_ball`` by the initial gap ``gap_0`` and sets
        ``self.TOL2 = self.data_ball * 1e-15``.
        """
        # Estimate the gap in the relevant metric.
        # Deactivated for now as it takes too long to compute.
        computeGap = False
        if computeGap:
            # simple for now...
            if self.formulation == 'product space':
                proxOps = self.productProxOperators
                u0 = self.u0[0]
            else:
                proxOps = self.proxOperators
                u0 = self.u0
            prox = proxOps[0](self)
            u1 = prox.eval(u0, 0)
            tmp_gap = 0
            for j in range(1, len(proxOps)):
                prox = proxOps[j](self)
                u2 = prox.eval(u1, j)
                tmp_gap += (norm(u1 - u2)/self.norm_data)**2
        else:
            # use cached value
            # NOTE(review): earlier notes paired 131.26250306395025 with the
            # variant `u2 = prox.eval(u2, j)` and 0.7668342258350885 with
            # `u2 = prox.eval(u1, j)` (the variant used above) -- confirm
            # which cached value is intended.
            tmp_gap = 131.26250306395025

        gap_0 = sqrt(tmp_gap)

        # sets the set fattening to be a percentage of the
        # initial gap to the unfattened set with
        # respect to the relevant metric (KL or L2),
        # that percentage given by
        # input.data_ball input by the user.
        self.data_ball = self.data_ball*gap_0
        # the second tolerance relative to the order of
        # magnitude of the metric
        self.TOL2 = self.data_ball*1e-15

    def show(self):
        """
        Display the results of the experiment.

        Draws a 2 x 2 figure: the first and the last monitored iterates as
        N x N images (column-major reshape), the iterate changes on a log
        scale and, when ``self.diagnostic`` is set and gaps were recorded,
        the gap history on a log scale (otherwise that panel is removed).
        """
        u_m = self.output['u_monitor']
        if isCell(u_m):
            u = u_m[0]
            u2 = u_m[len(u_m) - 1]
        else:
            u = self.output['u']
            u2 = u_m
        if isCell(u):
            u = u[0]
        elif u.ndim == 2:
            u = u[:, 0]
        N = int(np.sqrt(u.shape[0]))
        u = np.reshape(u, (N, N), order='F')
        u2 = np.reshape(u2, (N, N), order='F')

        f, ((ax1, ax2), (ax3, ax4)) = subplots(2, 2,
                                               figsize=(self.figure_width, self.figure_height),
                                               dpi=self.figure_dpi)
        self.createImageSubFigure(f, ax1, u, 'best approximation- physical domain')
        self.createImageSubFigure(f, ax2, u2, 'best approximation - data constraint')

        changes = self.output['stats']['changes']
        time = self.output['stats']['time']
        time_str = "{:.{}f}".format(time, 5)  # 5 is precision
        xLabel = "Iterations (time = " + time_str + " s)"
        algo_desc = self.algorithm.getDescription()
        title = "Algorithm " + algo_desc

        ax3.plot(changes)
        ax3.set_yscale('log')
        ax3.set_xlabel(xLabel)
        ax3.set_ylabel('log of iterate difference')
        ax3.set_title(title)

        if self.diagnostic and 'gaps' in self.output['stats']:
            gaps = self.output['stats']['gaps']
            ax4.semilogy(gaps)
            ax4.set_xlabel(xLabel)
            ax4.set_ylabel('Log of the gap distance')
        else:
            ax4.remove()

        show()

    def createImageSubFigure(self, f, ax, u, title=None):
        """
        Draw image ``u`` in grayscale on axes ``ax`` of figure ``f``.

        Adds a color bar next to the image and sets the axes title when
        ``title`` is not None.
        """
        im = ax.imshow(u, cmap='gray')
        # The "magic" values for fraction and pad adjust the
        # size of the color bar so that its height is comparable
        # to the plot:
        f.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        if title is not None:
            ax.set_title(title)