Behind the scenes
This section describes how to implement a new experiment class.
Creating a new experiment class
A new experiment class must be derived from the abstract Experiment class or, for phase retrieval problems, from the PhaseExperiment class, which is a specialization of Experiment. The Experiment class provides the infrastructure needed to develop a new experiment, so a concrete class only has to override a few methods.
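As an overview, the methods involved are the constructor, loadData(), setupProxOperators(), the static getDefaultParameters() and, optionally, show(), each of which is described below. The following skeleton is a minimal sketch, not the toolbox's code: ExperimentStandIn is a hypothetical stand-in for the real Experiment base class, and it assumes that parameters passed as keyword arguments become attributes (the examples below read self.Nx, self.sets and self.constraint in this way).

```python
import numpy as np


class ExperimentStandIn:
    """Hypothetical stand-in for the abstract Experiment base class."""

    def __init__(self, **kwargs):
        # assumption for this sketch: every parameter becomes an attribute
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.proxOperators = []
        self.nProx = 0

    def setupProxOperators(self):
        pass

    def show(self):
        pass


class MyExperiment(ExperimentStandIn):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.data = None  # set by loadData()

    @staticmethod
    def getDefaultParameters():
        return {'experiment_name': 'MyExperiment', 'Nx': 32, 'Ny': 32}

    def loadData(self):
        # generate (rather than load) a synthetic dataset, then the initial iterate
        rng = np.random.default_rng(0)
        self.data = rng.random((self.Nx, self.Ny))
        self.u0 = np.ones((self.Nx, self.Ny))

    def setupProxOperators(self):
        super().setupProxOperators()
        self.proxOperators = ['P_amp_support']  # names of prox operator classes
        self.nProx = len(self.proxOperators)


exp = MyExperiment(**MyExperiment.getDefaultParameters())
exp.loadData()
exp.setupProxOperators()
```

The order of calls at the end (construct, load data, set up prox operators) is chosen for the sketch; how the real base class sequences these steps is not shown in this section.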
As an example, we use the CDI_Experiment class, which inherits from the PhaseExperiment class:
```python
class CDI_Experiment(PhaseExperiment):
    """
    CDI experiment class
    """

    def __init__(self,
                 warmup_iter=0,
                 **kwargs):
        # call parent's __init__ method
        super(CDI_Experiment, self).__init__(**kwargs)

        # do here any data member initialization
        self.warmup_iter = warmup_iter

        # the following data members are set by loadData()
        self.magn = None
        self.farfield = None
        self.data_zeros = None
        self.support_idx = None
        self.abs_illumination = None
        self.supp_phase = None
```
Although the Experiment class provides most of the required data members, concrete classes usually define additional attributes, as CDI_Experiment does above.
Loading the data
A concrete experiment class needs to override the loadData() method:
```python
# assumed module-level imports for this excerpt
# (GetData is a ProxToolbox helper; its import is not shown here)
import numpy as np
from math import ceil, floor, log2
from scipy.io import loadmat


def loadData(self):
    """
    Load CDI dataset. Create the initial iterate.
    """
    # make sure input data can be found, otherwise download it
    GetData.getData('Phase')

    print('Loading data file CDI_intensity')
    f = loadmat('../InputData/Phase/CDI_intensity.mat')
    # diffraction pattern
    dp = f['intensity']['img'][0, 0]

    orig_res = max(dp.shape[0], dp.shape[1])  # actual data size
    step_up = ceil(log2(self.Nx) - log2(orig_res))
    workres = 2**(step_up) * 2**(floor(log2(orig_res)))  # desired array size
    N = int(workres)

    # [...]

    # initial iterate
    tmp_rnd = (np.random.rand(N, N)).T  # to match Matlab
    self.u0 = S * tmp_rnd
    self.u0 = self.u0 / np.linalg.norm(self.u0, 'fro') * self.norm_rt_data
```
This method loads (or generates) the dataset used by the experiment. It must also create the initial iterate, stored in self.u0.
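The array-size computation and the scaling of the initial iterate in loadData() can be checked in isolation. The sketch below extracts them into two standalone functions; working_size and initial_iterate are hypothetical names, and S, whose definition is elided above, is assumed here to be a support mask. Unlike the original, it uses a seeded NumPy Generator (and drops the Matlab-matching transpose) so that results are reproducible.

```python
import numpy as np
from math import ceil, floor, log2


def working_size(orig_res, Nx):
    """Power-of-two working array size, computed as in loadData()."""
    step_up = ceil(log2(Nx) - log2(orig_res))
    workres = 2**step_up * 2**floor(log2(orig_res))
    return int(workres)


def initial_iterate(S, norm_rt_data, seed=0):
    """Random iterate masked by S and scaled to Frobenius norm norm_rt_data."""
    rng = np.random.default_rng(seed)
    N = S.shape[0]
    u0 = S * rng.random((N, N))
    return u0 / np.linalg.norm(u0, 'fro') * norm_rt_data


# a 256 x 256 pattern with Nx = 128 gives step_up = -1, hence N = 0.5 * 256 = 128
N = working_size(256, 128)

S = np.zeros((8, 8))
S[2:6, 2:6] = 1.0          # hypothetical square support
u0 = initial_iterate(S, norm_rt_data=3.0)
```

Because the iterate is divided by its own Frobenius norm before being multiplied by norm_rt_data, its norm matches that of the data regardless of the random draw, and it is zero outside the support.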
Setting up the prox operators
Each experiment class needs to specify which prox operators it requires. For CDI_Experiment, part of this work is already done in the parent class, based on the value of the constraint parameter.
```python
def setupProxOperators(self):
    """
    Determine the prox operators to be used for this experiment
    """
    super(CDI_Experiment, self).setupProxOperators()  # call parent's method

    self.propagator = 'Propagator_FreFra'
    self.inverse_propagator = 'InvPropagator_FreFra'

    # remark: self.farfield is always true (set in data processor)
    # self.proxOperators already contains a prox operator at slot 0.
    # Here, we add the second one.
    if self.constraint == 'phaselift':
        self.proxOperators.append('P_Rank1')
    elif self.constraint == 'phaselift2':
        self.proxOperators.append('P_rank1_SR')
    else:
        self.proxOperators.append('Approx_Pphase_FreFra_Poisson')
    self.nProx = self.sets
```
The setupProxOperators() method must fill the proxOperators attribute with the prox operators needed by the experiment. This attribute is a list of strings, each of which is the name of a prox operator class; the actual prox operator instances are created automatically from these names later in the initialization process. The nProx attribute gives the number of prox operators used. The propagator and inverse_propagator attributes must also be set, because the Approx_Pphase_FreFra_Poisson prox operator requires them.
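One way to turn such a list of names into instances is a lookup from class name to class. The sketch below is only an illustration of that idea under stated assumptions: PROX_REGISTRY, instantiate_prox_operators and DemoExperiment are hypothetical, the two prox operator classes are empty stand-ins, and the toolbox's actual creation mechanism is not shown in this section.

```python
class P_amp_support:
    """Hypothetical stand-in for the real prox operator class."""
    def __init__(self, experiment):
        self.experiment = experiment


class Approx_Pphase_FreFra_Poisson:
    """Hypothetical stand-in; reads the propagator names set by the experiment."""
    def __init__(self, experiment):
        self.propagator = experiment.propagator
        self.inverse_propagator = experiment.inverse_propagator


# hypothetical registry: class name -> class
PROX_REGISTRY = {cls.__name__: cls
                 for cls in (P_amp_support, Approx_Pphase_FreFra_Poisson)}


def instantiate_prox_operators(experiment):
    """Create one prox operator instance per name in experiment.proxOperators."""
    return [PROX_REGISTRY[name](experiment) for name in experiment.proxOperators]


class DemoExperiment:
    def __init__(self):
        self.proxOperators = ['Approx_Pphase_FreFra_Poisson', 'P_amp_support']
        self.nProx = len(self.proxOperators)
        self.propagator = 'Propagator_FreFra'
        self.inverse_propagator = 'InvPropagator_FreFra'


exp = DemoExperiment()
ops = instantiate_prox_operators(exp)
```

If DemoExperiment did not set propagator, constructing the Approx_Pphase_FreFra_Poisson stand-in would raise AttributeError, which illustrates why the experiment must set these attributes before the prox operators are created.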
The following code, taken from the JWST_Experiment class, gives another example, this time supporting a product space formulation:
```python
def setupProxOperators(self):
    """
    Determine the prox operators to be used for this experiment
    """
    super(JWST_Experiment, self).setupProxOperators()  # call parent's method

    self.proxOperators = []
    self.productProxOperators = []

    if self.formulation == 'cyclic':
        # there are as many prox operators as there are sets
        self.nProx = self.sets
        self.product_space_dimension = 1
        for _j in range(self.nProx-1):
            self.proxOperators.append('Approx_Pphase_FreFra_Poisson')
        self.proxOperators.append('P_amp_support')
    else:  # product space formulation
        # add prox operators
        self.nProx = 2
        self.proxOperators.append('P_diag')
        self.proxOperators.append('Prox_product_space')
        # add product prox operators
        self.n_product_Prox = self.product_space_dimension
        for _j in range(self.n_product_Prox-1):
            self.productProxOperators.append('Approx_Pphase_FreFra_Poisson')
        self.productProxOperators.append('P_amp_support')

    # add propagator and inverse propagator
    # used by Approx_Pphase_FreFra_Poisson prox operator
    self.propagator = 'Propagator_FreFra'
    self.inverse_propagator = 'InvPropagator_FreFra'
```
In a product space formulation, the productProxOperators attribute must also be filled. The n_product_Prox attribute gives the number of product prox operators used.
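To illustrate what the two top-level operators of the product space formulation do, the following NumPy sketch assumes, as their names suggest, that P_diag projects onto the diagonal of the product space (all copies equal) and that Prox_product_space applies the j-th product prox operator to the j-th copy of the iterate. The component operators clip_nonneg and clip_upper are hypothetical stand-ins for real product prox operators such as Approx_Pphase_FreFra_Poisson and P_amp_support; these functions are not the toolbox's implementations.

```python
import numpy as np


def P_diag(xs):
    """Project m stacked copies onto the diagonal {x_1 = ... = x_m} by averaging."""
    mean = xs.mean(axis=0)
    return np.broadcast_to(mean, xs.shape).copy()


def prox_product_space(xs, prox_list):
    """Apply the j-th prox operator to the j-th copy, one per product set."""
    return np.stack([prox(x) for prox, x in zip(prox_list, xs)])


# hypothetical component prox operators: projections onto simple convex sets
def clip_nonneg(x):
    return np.maximum(x, 0.0)


def clip_upper(x):
    return np.minimum(x, 1.0)


proxes = [clip_nonneg, clip_upper]   # n_product_Prox = 2
xs = np.array([[-1.0, 2.0],          # copy 1
               [0.5, 3.0]])          # copy 2
ys = prox_product_space(xs, proxes)  # component-wise step
zs = P_diag(ys)                      # consensus step
```

With this reading, the number of stacked copies (the first axis of xs) plays the role of product_space_dimension, and alternating between the two operators turns a problem with many constraint sets into one with only two, which matches nProx = 2 in the code above.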
Defining the default parameters
Every experiment class must provide default parameters by defining the static method getDefaultParameters(), which returns them as a dictionary:
```python
@staticmethod
def getDefaultParameters():
    defaultParams = {
        'experiment_name': 'CDI',
        'object': 'nonnegative',
        'constraint': 'nonnegative and support',
        'Nx': 128,
        'Ny': 128,
        'Nz': 1,
        'sets': 10,
        'farfield': True,
        'MAXIT': 6000,
        'TOL': 1e-8,
        'lambda_0': 0.5,
        'lambda_max': 0.50,
        'lambda_switch': 30,
        'data_ball': .999826e-30,
        'diagnostic': True,
        'iterate_monitor_name': 'FeasibilityIterateMonitor',
        'rotate': False,
        'verbose': 0,
        'graphics': 1,
        'anim': False,
        'debug': True
    }
    return defaultParams
```
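Because getDefaultParameters() is a static method, it can be called on the class itself, before any instance exists. The sketch below shows one plausible way such defaults could be combined with user-supplied values; make_params and DemoExperiment are hypothetical, and this section does not show how the toolbox actually merges parameters.

```python
class DemoExperiment:
    @staticmethod
    def getDefaultParameters():
        # a small subset of the CDI defaults listed above
        return {'experiment_name': 'CDI', 'Nx': 128, 'MAXIT': 6000, 'TOL': 1e-8}


def make_params(experiment_cls, **overrides):
    """Hypothetical helper: class defaults updated with user-supplied values."""
    params = experiment_cls.getDefaultParameters()  # no instance needed
    params.update(overrides)
    return params


params = make_params(DemoExperiment, MAXIT=100, Nx=64)
```

Since the dictionary literal is rebuilt on every call, updating the returned dictionary never alters the defaults seen by later callers.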
Generating the graphical output
Each experiment can generate graphical output from the computed solution by implementing the show() method. For the CDI_Experiment class, most of the work is already done in the parent class, which provides a common implementation for all phase retrieval experiments. Here, show() is overridden to display the input far field data and the support constraint, and then calls the parent's method to display the other plots.
```python
# assumed module-level imports for this excerpt
import matplotlib.pyplot as plt
from matplotlib.pyplot import subplots
from numpy import log10


def show(self):
    """
    Generate graphical output from the solution
    """
    # display plot of far field data and support constraint
    # figure(123)
    f, (ax1, ax2) = subplots(1, 2,
                             figsize=(self.figure_width, self.figure_height),
                             dpi=self.figure_dpi)
    im = ax1.imshow(log10(self.dp + 1e-15))
    f.colorbar(im, ax=ax1, fraction=0.046, pad=0.04)
    ax1.set_title('Far field data')
    im = ax2.imshow(self.abs_illumination)
    ax2.set_title('Support constraint')
    plt.subplots_adjust(wspace=0.3)  # adjust horizontal space (width)
                                     # between subplots (default = 0.2)
    f.suptitle('CDI Data')

    # call parent to display the other plots
    super(CDI_Experiment, self).show()
```
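The override pattern used by show() (draw the experiment-specific figure, then delegate to the parent's show()) can be reproduced outside the toolbox. In this sketch, BasePhaseDemo and CDIDemo are hypothetical stand-ins for PhaseExperiment and CDI_Experiment, the data are random, and the non-interactive Agg backend is selected so that it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: no window is opened
import matplotlib.pyplot as plt
import numpy as np


class BasePhaseDemo:
    """Hypothetical stand-in for the parent's common show() implementation."""
    def show(self):
        fig, ax = plt.subplots()
        ax.imshow(np.abs(self.u))
        ax.set_title('Reconstruction (amplitude)')


class CDIDemo(BasePhaseDemo):
    def __init__(self, dp, support, u):
        self.dp = dp
        self.abs_illumination = support
        self.u = u

    def show(self):
        # experiment-specific figure first
        fig, (ax1, ax2) = plt.subplots(1, 2)
        im = ax1.imshow(np.log10(self.dp + 1e-15))
        fig.colorbar(im, ax=ax1)
        ax1.set_title('Far field data')
        ax2.imshow(self.abs_illumination)
        ax2.set_title('Support constraint')
        fig.suptitle('CDI Data')
        # then the plots shared by all phase retrieval experiments
        super().show()


plt.close('all')
rng = np.random.default_rng(0)
support = np.zeros((16, 16))
support[4:12, 4:12] = 1.0
demo = CDIDemo(dp=rng.random((16, 16)), support=support,
               u=rng.random((16, 16)) + 0j)
demo.show()
```

Calling the parent's method last keeps the shared plots in one place, so each experiment only adds the figures that are specific to it.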