Behind the scenes

This section describes how to implement a new experiment class.

Creating a new experiment class

Such a class must be derived from the abstract Experiment class or, for phase retrieval problems, from the PhaseExperiment class, which is a specialization of Experiment. The Experiment class provides the infrastructure needed to develop a new experiment, so a concrete subclass only has to override a few methods.
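A minimal skeleton of such a subclass might look like the sketch below. The stand-in Experiment base class is hypothetical and exists only so the snippet runs on its own; in practice you would import the toolbox's real Experiment or PhaseExperiment class, whose constructor and initialization sequence are not reproduced here. MyExperiment and its parameter values are likewise illustrative.

```python
import numpy as np


class Experiment:
    """Hypothetical stand-in for the toolbox's abstract Experiment class."""

    def __init__(self, **kwargs):
        # store every keyword parameter as an attribute
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.proxOperators = []
        self.nProx = 0
        self.u0 = None

    def loadData(self):
        raise NotImplementedError

    def setupProxOperators(self):
        raise NotImplementedError

    def show(self):
        pass


class MyExperiment(Experiment):
    """Hypothetical experiment overriding only the hooks described in this section."""

    @staticmethod
    def getDefaultParameters():
        return {'experiment_name': 'MyExperiment', 'Nx': 64, 'Ny': 64}

    def loadData(self):
        # load or generate the dataset, then create the initial iterate
        self.data = np.ones((self.Nx, self.Ny))
        self.u0 = np.zeros_like(self.data)

    def setupProxOperators(self):
        # names of prox operator classes; the framework instantiates them later
        self.proxOperators = ['P_amp_support', 'Approx_Pphase_FreFra_Poisson']
        self.nProx = len(self.proxOperators)

    def show(self):
        # experiment-specific plots would go here, then delegate to the parent
        super().show()


exp = MyExperiment(**MyExperiment.getDefaultParameters())
exp.loadData()
exp.setupProxOperators()
```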

As an example, we use the CDI_Experiment class, which inherits from PhaseExperiment:

class CDI_Experiment(PhaseExperiment):
   """
   CDI experiment class
   """

   def __init__(self,
              warmup_iter=0,
              **kwargs):
      # call parent's __init__ method
      super(CDI_Experiment, self).__init__(**kwargs)

      # do here any data member initialization
      self.warmup_iter = warmup_iter

      # the following data members are set by loadData()
      self.magn = None
      self.farfield = None
      self.data_zeros = None
      self.support_idx = None
      self.abs_illumination = None
      self.supp_phase = None

Although the Experiment class provides most of the required data members, concrete classes usually define additional attributes of their own, as CDI_Experiment does with warmup_iter and the data members set by loadData().
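The constructor above consumes its own keyword argument (warmup_iter) and forwards all remaining keywords to the parent. The sketch below isolates that pattern with hypothetical Base and Child classes. It uses Python 3's zero-argument super(), which inside a method is equivalent to the explicit super(CDI_Experiment, self) form used in the original.

```python
class Base:
    """Hypothetical parent accepting generic experiment parameters."""

    def __init__(self, Nx=128, Ny=128, **kwargs):
        # any other keywords are accepted and ignored in this sketch
        self.Nx = Nx
        self.Ny = Ny


class Child(Base):
    """Hypothetical subclass with one extra parameter of its own."""

    def __init__(self, warmup_iter=0, **kwargs):
        # warmup_iter is captured here; the remaining keywords go to Base
        super().__init__(**kwargs)
        self.warmup_iter = warmup_iter
        # placeholders that a loadData() method would fill in later
        self.magn = None
        self.support_idx = None


c = Child(warmup_iter=5, Nx=256)
```

Because warmup_iter is a named parameter of Child, it never reaches Base, while Nx passes through **kwargs untouched.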

Loading the data

A concrete experiment class needs to override the loadData() method:

def loadData(self):
   """
   Load CDI dataset. Create the initial iterate.
   """

   # make sure input data can be found, otherwise download it
   GetData.getData('Phase')

   print('Loading data file CDI_intensity')
   f = loadmat('../InputData/Phase/CDI_intensity.mat')
   # diffraction pattern
   dp = f['intensity']['img'][0, 0]
   orig_res = max(dp.shape[0], dp.shape[1])  # actual data size
   step_up = ceil(log2(self.Nx) - log2(orig_res))
   workres = 2**(step_up) * 2**(floor(log2(orig_res)))  # desired array size
   N = int(workres)

   [...]

   # initial iterate
   tmp_rnd = (np.random.rand(N, N)).T # to match Matlab
   self.u0 = S * tmp_rnd
   self.u0 = self.u0 / np.linalg.norm(self.u0, 'fro') * self.norm_rt_data

The role of this method is to load or generate the dataset used by the experiment. It must also create the initial iterate, which is stored in self.u0.
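The array-size computation and the normalization of the initial iterate in the excerpt can be checked in isolation. In the sketch below, working_resolution and initial_iterate are hypothetical helper names. S stands in for the support mask computed in the elided part of loadData(), which is an assumption, and NumPy's Generator API replaces the legacy np.random.rand, so the random values will not match the original.

```python
from math import ceil, floor, log2

import numpy as np


def working_resolution(Nx, orig_res):
    """Array size computed with the same formula as in loadData()."""
    step_up = ceil(log2(Nx) - log2(orig_res))
    workres = 2**step_up * 2**floor(log2(orig_res))  # float when step_up < 0
    return int(workres)


def initial_iterate(S, norm_rt_data, seed=None):
    """Random iterate restricted to S and scaled to a given Frobenius norm.

    S is assumed to be a square support mask of shape (N, N).
    """
    rng = np.random.default_rng(seed)
    N = S.shape[0]
    u0 = S * rng.random((N, N)).T
    return u0 / np.linalg.norm(u0, 'fro') * norm_rt_data


N = working_resolution(512, 300)  # 512
```

Dividing by the Frobenius norm and multiplying by norm_rt_data gives u0 exactly the norm norm_rt_data, while the multiplication by S keeps the iterate zero outside the support.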

Setting up the prox operators

Each experiment class must specify which prox operators it requires. For CDI_Experiment, part of this work is already done by the parent class, based on the constraint parameter.

def setupProxOperators(self):
   """
   Determine the prox operators to be used for this experiment
   """
   super(CDI_Experiment, self).setupProxOperators()  # call parent's method

   self.propagator = 'Propagator_FreFra'
   self.inverse_propagator = 'InvPropagator_FreFra'

   # remark: self.farfield is always true (set in data processor)
   # self.proxOperators already contains a prox operator at slot 0.
   # Here, we add the second one.
   if self.constraint == 'phaselift':
      self.proxOperators.append('P_Rank1')
   elif self.constraint == 'phaselift2':
      self.proxOperators.append('P_rank1_SR')
   else:
      self.proxOperators.append('Approx_Pphase_FreFra_Poisson')
   self.nProx = self.sets

The setupProxOperators() method must fill the proxOperators attribute with the prox operators needed by the experiment. proxOperators is a list of strings, each naming a prox operator class; the actual prox operator instances are created automatically from these names later in the initialization process. The nProx attribute gives the number of prox operators used. The propagator and inverse_propagator attributes must be set because the Approx_Pphase_FreFra_Poisson prox operator requires them.
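How the toolbox turns these names into objects is not described here. One common approach is a registry that maps class names to classes; the sketch below assumes that approach, uses simplified stand-ins for two of the prox operator classes, and introduces the helper instantiate_prox_operators purely for illustration. The stand-in for Approx_Pphase_FreFra_Poisson reads the propagator names at construction time, which shows why those attributes have to be set beforehand.

```python
from types import SimpleNamespace


class P_Rank1:
    """Hypothetical stand-in for a prox operator class."""

    def __init__(self, experiment):
        self.experiment = experiment


class Approx_Pphase_FreFra_Poisson:
    """Hypothetical stand-in; it depends on the experiment's propagators."""

    def __init__(self, experiment):
        # these attributes must already exist when the operator is built
        self.propagator = experiment.propagator
        self.inverse_propagator = experiment.inverse_propagator


# map each class name to the class itself
PROX_REGISTRY = {cls.__name__: cls for cls in (P_Rank1, Approx_Pphase_FreFra_Poisson)}


def instantiate_prox_operators(experiment):
    """Turn the names in experiment.proxOperators into operator instances."""
    return [PROX_REGISTRY[name](experiment) for name in experiment.proxOperators]


exp = SimpleNamespace(
    proxOperators=['P_Rank1', 'Approx_Pphase_FreFra_Poisson'],
    propagator='Propagator_FreFra',
    inverse_propagator='InvPropagator_FreFra',
)
ops = instantiate_prox_operators(exp)
```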

The following code, taken from the JWST_Experiment class, gives another example that also supports a product space formulation.

def setupProxOperators(self):
   """
   Determine the prox operators to be used for this experiment
   """
   super(JWST_Experiment, self).setupProxOperators() # call parent's method

   self.proxOperators = []
   self.productProxOperators = []
   if self.formulation == 'cyclic':
      # there are as many prox operators as there are sets
      self.nProx = self.sets
      self.product_space_dimension = 1
      for _j in range(self.nProx-1):
         self.proxOperators.append('Approx_Pphase_FreFra_Poisson')
      self.proxOperators.append('P_amp_support')
   else: # product space formulation
      # add prox operators
      self.nProx = 2
      self.proxOperators.append('P_diag')
      self.proxOperators.append('Prox_product_space')
      # add product prox operators
      self.n_product_Prox = self.product_space_dimension
      for _j in range(self.n_product_Prox-1):
         self.productProxOperators.append('Approx_Pphase_FreFra_Poisson')
      self.productProxOperators.append('P_amp_support')

   # add propagator and inverse propagator
   # used by Approx_Pphase_FreFra_Poisson prox operator
   self.propagator = 'Propagator_FreFra'
   self.inverse_propagator = 'InvPropagator_FreFra'

In a product space formulation, the productProxOperators attribute must also be filled, and the n_product_Prox attribute indicates how many product prox operators are used. proxOperators then holds only two operators, P_diag and Prox_product_space.
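In the standard product space reformulation, several constraint sets are replaced by two sets in the product space: the product of all sets, whose prox acts on each component separately, and the diagonal {(x, ..., x)}, whose Euclidean projection replaces every component by the average of all components. The names P_diag and Prox_product_space suggest these two roles, but that is an assumption; the NumPy sketch below illustrates the mathematics with hypothetical functions and is not the toolbox's implementation of those operators.

```python
import numpy as np


def project_diagonal(u):
    """Project a stack u of shape (K, ...) onto {(x, ..., x)} by averaging the K components."""
    mean = u.mean(axis=0)
    return np.broadcast_to(mean, u.shape).copy()


def prox_product_space(u, prox_list):
    """Apply the k-th prox operator to the k-th component of u."""
    return np.stack([prox(u[k]) for k, prox in enumerate(prox_list)])


# K = 3 components; simple projections stand in for the real prox operators
u = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
proxes = [
    lambda x: np.maximum(x, 0.0),    # projection onto the nonnegative orthant
    lambda x: np.maximum(x, 0.0),
    lambda x: np.clip(x, 0.0, 4.0),  # projection onto the box [0, 4]
]
```

An algorithm then only alternates between two operators, which is why nProx is set to 2 in this branch regardless of how many product prox operators there are.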

Defining the default parameters

Every experiment class must provide default parameters by defining the static method getDefaultParameters(), which returns a dictionary mapping parameter names to their default values:

@staticmethod
def getDefaultParameters():
   defaultParams = {
      'experiment_name': 'CDI',
      'object': 'nonnegative',
      'constraint': 'nonnegative and support',
      'Nx': 128,
      'Ny': 128,
      'Nz': 1,
      'sets': 10,
      'farfield': True,
      'MAXIT': 6000,
      'TOL': 1e-8,
      'lambda_0': 0.5,
      'lambda_max': 0.50,
      'lambda_switch': 30,
      'data_ball': .999826e-30,
      'diagnostic': True,
      'iterate_monitor_name': 'FeasibilityIterateMonitor',
      'rotate': False,
      'verbose': 0,
      'graphics': 1,
      'anim': False,
      'debug': True
   }
   return defaultParams
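Because getDefaultParameters() is a static method, the defaults can be queried without creating an instance. How the toolbox combines them with user-supplied values is not described here; a plain dictionary merge in which user values take precedence is one reasonable approach and is what the sketch assumes. The function build_parameters is hypothetical, and CDIDefaults keeps only a few of the CDI defaults.

```python
class CDIDefaults:
    """Hypothetical, reduced copy of CDI_Experiment's default parameters."""

    @staticmethod
    def getDefaultParameters():
        return {'experiment_name': 'CDI', 'Nx': 128, 'Ny': 128, 'MAXIT': 6000, 'TOL': 1e-8}


def build_parameters(experiment_class, **overrides):
    """Start from the class defaults; explicitly given values win."""
    params = experiment_class.getDefaultParameters()
    return {**params, **overrides}


params = build_parameters(CDIDefaults, MAXIT=500)
```

Since the method builds a new dictionary on every call, as the original does, callers can modify the result without altering the defaults seen by later calls.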

Generating the graphical output

By implementing the show() method, an experiment can generate graphical output from the computed solution. For the CDI_Experiment class, most of the work is already done in the parent class, which provides a common implementation for all phase retrieval experiments. Here, show() is overridden to additionally display the far field data and the support constraint before calling the parent's implementation.

def show(self):
   """
   Generate graphical output from the solution
   """

   # display plot of far field data and support constraint
   f, (ax1, ax2) = plt.subplots(1, 2,
                                figsize=(self.figure_width, self.figure_height),
                                dpi=self.figure_dpi)
   # log scale; the small offset avoids taking log10 of zero
   im = ax1.imshow(np.log10(self.dp + 1e-15))
   f.colorbar(im, ax=ax1, fraction=0.046, pad=0.04)
   ax1.set_title('Far field data')
   im = ax2.imshow(self.abs_illumination)
   ax2.set_title('Support constraint')
   plt.subplots_adjust(wspace=0.3)  # adjust horizontal space (width)
                                    # between subplots (default = 0.2)
   f.suptitle('CDI Data')

   # call parent to display the other plots
   super(CDI_Experiment, self).show()
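The sketch below strips the plotting away to show the two ideas in this method: the override does its experiment-specific work and then calls the parent's show(), and the far field data are displayed on a log scale with a small offset so that zero intensities map to -15 rather than -inf. The stub classes and the shown list are hypothetical stand-ins for the real figures.

```python
import numpy as np


class PhaseExperimentStub:
    """Hypothetical parent whose show() produces the common plots."""

    def show(self):
        self.shown.append('common phase retrieval plots')


class CDIStub(PhaseExperimentStub):
    """Hypothetical child: draws its own panels first, then delegates."""

    def __init__(self, dp):
        self.dp = dp
        self.shown = []

    def show(self):
        # log-scaled data; 1e-15 keeps log10 finite where the intensity is zero
        self.log_dp = np.log10(self.dp + 1e-15)
        self.shown.append('far field data and support constraint')
        super().show()


exp = CDIStub(np.array([[0.0, 1.0], [10.0, 100.0]]))
exp.show()
```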