Behind the scenes
-----------------

The focus of this section is to describe how a new experiment class can be
implemented.

.. contents::
   :local:

Creating a new experiment class
+++++++++++++++++++++++++++++++

Such a class must be derived from the abstract :class:`Experiment` class or
possibly the :class:`PhaseExperiment` class. The latter is a specialization of
the former which is used to model phase retrieval problems.
The :class:`Experiment` class provides the infrastructure that is required to
develop a new experiment. What remains to be done is to override a few methods.

As an example of implementation, we will use the :class:`CDI_Experiment` class.
This class inherits from the :class:`PhaseExperiment` class:

.. code-block:: python
   :linenos:

   class CDI_Experiment(PhaseExperiment):
       """
       CDI experiment class
       """

       def __init__(self, warmup_iter=0, **kwargs):
           # call parent's __init__ method
           super(CDI_Experiment, self).__init__(**kwargs)

           # do here any data member initialization
           self.warmup_iter = warmup_iter

           # the following data members are set by loadData()
           self.magn = None
           self.farfield = None
           self.data_zeros = None
           self.support_idx = None
           self.abs_illumination = None
           self.supp_phase = None

Although the :class:`Experiment` class provides most of the required data
members, most concrete classes will define additional attributes.

Loading the data
++++++++++++++++

A concrete experiment class needs to override the :meth:`loadData` method:

.. code-block:: python
   :linenos:

   def loadData(self):
       """
       Load CDI dataset. Create the initial iterate.
       """
       # make sure input data can be found, otherwise download it
       GetData.getData('Phase')

       print('Loading data file CDI_intensity')
       f = loadmat('../InputData/Phase/CDI_intensity.mat')
       # diffraction pattern
       dp = f['intensity']['img'][0, 0]
       orig_res = max(dp.shape[0], dp.shape[1])  # actual data size
       step_up = ceil(log2(self.Nx) - log2(orig_res))
       workres = 2**(step_up) * 2**(floor(log2(orig_res)))  # desired array size
       N = int(workres)

       [...]
       # initial iterate
       tmp_rnd = (np.random.rand(N, N)).T  # to match Matlab
       self.u0 = S * tmp_rnd
       self.u0 = self.u0 / np.linalg.norm(self.u0, 'fro') * self.norm_rt_data

The role of this method is to load or generate the dataset that will be used
for this experiment. It must also create the initial iterate.

Setting up the prox operators
+++++++++++++++++++++++++++++

Each experiment class needs to specify which prox operators are required.
In this case, some of the work has already been done in the parent class,
based on the given :attr:`constraint` parameter.

.. code-block:: python
   :linenos:

   def setupProxOperators(self):
       """
       Determine the prox operators to be used for this experiment
       """
       super(CDI_Experiment, self).setupProxOperators()  # call parent's method

       self.propagator = 'Propagator_FreFra'
       self.inverse_propagator = 'InvPropagator_FreFra'

       # remark: self.farfield is always true (set in data processor)
       # self.proxOperators already contains a prox operator at slot 0.
       # Here, we add the second one.
       if self.constraint == 'phaselift':
           self.proxOperators.append('P_Rank1')
       elif self.constraint == 'phaselift2':
           self.proxOperators.append('P_rank1_SR')
       else:
           self.proxOperators.append('Approx_Pphase_FreFra_Poisson')

       self.nProx = self.sets

The :meth:`setupProxOperators` method must fill the :attr:`proxOperators`
attribute, which lists the prox operators that are needed for this experiment.
It is a list of strings where each string is the name of a prox operator class.
Based on these names, the actual prox operator instances are created
automatically later on during the initialization process.
The attribute :attr:`nProx` indicates how many prox operators are used.
Setting the :attr:`propagator` and :attr:`inverse_propagator` attributes
is required by the :class:`Approx_Pphase_FreFra_Poisson` prox operator.

The following code comes from the :class:`JWST_Experiment` class and gives
another example, this time taking a product space formulation into account.

.. code-block:: python
   :linenos:

   def setupProxOperators(self):
       """
       Determine the prox operators to be used for this experiment
       """
       super(JWST_Experiment, self).setupProxOperators()  # call parent's method

       self.proxOperators = []
       self.productProxOperators = []
       if self.formulation == 'cyclic':
           # there are as many prox operators as there are sets
           self.nProx = self.sets
           self.product_space_dimension = 1
           for _j in range(self.nProx-1):
               self.proxOperators.append('Approx_Pphase_FreFra_Poisson')
           self.proxOperators.append('P_amp_support')
       else:  # product space formulation
           # add prox operators
           self.nProx = 2
           self.proxOperators.append('P_diag')
           self.proxOperators.append('Prox_product_space')
           # add product prox operators
           self.n_product_Prox = self.product_space_dimension
           for _j in range(self.n_product_Prox-1):
               self.productProxOperators.append('Approx_Pphase_FreFra_Poisson')
           self.productProxOperators.append('P_amp_support')

       # add propagator and inverse propagator
       # used by Approx_Pphase_FreFra_Poisson prox operator
       self.propagator = 'Propagator_FreFra'
       self.inverse_propagator = 'InvPropagator_FreFra'

In the case of a product space formulation, the attribute
:attr:`productProxOperators` must also be filled. The attribute
:attr:`n_product_Prox` indicates how many product prox operators are used.

Defining the default parameters
+++++++++++++++++++++++++++++++

Any experiment class must provide default parameters by defining the
static method :meth:`getDefaultParameters`:

.. code-block:: python
   :linenos:

   @staticmethod
   def getDefaultParameters():
       defaultParams = {
           'experiment_name': 'CDI',
           'object': 'nonnegative',
           'constraint': 'nonnegative and support',
           'Nx': 128,
           'Ny': 128,
           'Nz': 1,
           'sets': 10,
           'farfield': True,
           'MAXIT': 6000,
           'TOL': 1e-8,
           'lambda_0': 0.5,
           'lambda_max': 0.50,
           'lambda_switch': 30,
           'data_ball': .999826e-30,
           'diagnostic': True,
           'iterate_monitor_name': 'FeasibilityIterateMonitor',
           'rotate': False,
           'verbose': 0,
           'graphics': 1,
           'anim': False,
           'debug': True
       }
       return defaultParams

Generating the graphical output
+++++++++++++++++++++++++++++++

By implementing the :meth:`show` method, each experiment can generate
graphical output from the obtained solution. In the case of the
:class:`CDI_Experiment` class, most of the work is already done in the parent
class, which provides a common implementation for all the phase retrieval
experiments. Here, the :meth:`show` method is overridden to display the
initial far field data and support constraint.

.. code-block:: python
   :linenos:

   def show(self):
       """
       Generate graphical output from the solution
       """
       # display plot of far field data and support constraint
       # figure(123)
       f, (ax1, ax2) = subplots(1, 2,
                                figsize=(self.figure_width, self.figure_height),
                                dpi=self.figure_dpi)
       im = ax1.imshow(log10(self.dp + 1e-15))
       f.colorbar(im, ax=ax1, fraction=0.046, pad=0.04)
       ax1.set_title('Far field data')
       im = ax2.imshow(self.abs_illumination)
       ax2.set_title('Support constraint')
       plt.subplots_adjust(wspace=0.3)  # adjust horizontal space (width)
                                        # between subplots (default = 0.2)
       f.suptitle('CDI Data')

       # call parent to display the other plots
       super(CDI_Experiment, self).show()
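
Taken together, the overrides described in this section can be sketched on a
toy class. The example below is only an illustration under stated assumptions:
the ``Experiment`` class defined here is a hypothetical stand-in written for
this sketch, not the toolbox's own class, and the names ``ToyExperiment``,
``initialize``, ``seed``, ``P_first`` and ``P_second`` are likewise invented.
In particular, the way the stand-in merges the defaults with the keyword
arguments and the order in which it calls the overridden methods are
assumptions, not a description of the real infrastructure.

.. code-block:: python

   import math
   import random


   class Experiment:
       """Hypothetical stand-in for the abstract Experiment class."""

       def __init__(self, **kwargs):
           # assumption: user-supplied values override the class defaults
           params = dict(self.getDefaultParameters())
           params.update(kwargs)
           for name, value in params.items():
               setattr(self, name, value)
           self.u0 = None
           self.proxOperators = []
           self.nProx = 0

       def initialize(self):
           # assumed call order: load the data, then name the prox operators
           self.loadData()
           self.setupProxOperators()

       def loadData(self):
           raise NotImplementedError

       def setupProxOperators(self):
           # the stand-in parent starts from an empty list
           self.proxOperators = []

       @staticmethod
       def getDefaultParameters():
           raise NotImplementedError

       def show(self):
           print('initial iterate has', len(self.u0), 'rows')


   class ToyExperiment(Experiment):
       """Toy experiment (hypothetical) providing the four overrides."""

       def __init__(self, seed=0, **kwargs):
           # call parent's __init__ method, then add our own data members
           super(ToyExperiment, self).__init__(**kwargs)
           self.seed = seed

       def loadData(self):
           # generate a random Nx-by-Nx initial iterate, scaled to unit
           # Frobenius norm (plain lists are used to avoid dependencies)
           rng = random.Random(self.seed)
           u = [[rng.random() for _ in range(self.Nx)] for _ in range(self.Nx)]
           norm = math.sqrt(sum(x * x for row in u for x in row))
           self.u0 = [[x / norm for x in row] for row in u]

       def setupProxOperators(self):
           super(ToyExperiment, self).setupProxOperators()
           # placeholder names; real names must match prox operator classes
           self.proxOperators.append('P_first')
           self.proxOperators.append('P_second')
           self.nProx = len(self.proxOperators)

       @staticmethod
       def getDefaultParameters():
           return {'experiment_name': 'Toy', 'Nx': 4, 'MAXIT': 100, 'TOL': 1e-8}

       def show(self):
           print('prox operators:', self.proxOperators)
           # call parent to display the remaining output
           super(ToyExperiment, self).show()


   exp = ToyExperiment(Nx=8)  # override one default parameter
   exp.initialize()
   exp.show()

The sketch keeps the same division of labour as the examples above: the
constructor only declares data members, :meth:`loadData` creates the initial
iterate, :meth:`setupProxOperators` records names rather than instances, and
:meth:`show` delegates to the parent after producing its own output.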