from numpy import pi, sqrt, conj
from numpy.fft import fft2, ifft2, fft, ifft
from proxtoolbox.utils.OrbitalTomog import shifted_ifft, shifted_fft
# Names exported by ``from <module> import *``: the forward and inverse
# propagator classes. Propagator_FreFra_Base is not listed here.
__all__ = ['PropagatorFFTn', 'InvPropagatorFFTn',
'PropagatorFFT2', 'InvPropagatorFFT2',
'Propagator_FreFra', 'InvPropagator_FreFra']
class PropagatorFFTn:
    """
    Basic propagator which uses the n-dimensional Fourier transform,
    transforming over all axes by default.

    The transform itself is delegated to ``shifted_fft``; ``self.axes``
    is the value forwarded as its ``axes`` argument.
    """

    def __init__(self, experiment):
        """
        Initialize propagator

        :param experiment: experiment class, can be used to pass options
            as attributes of the class (this implementation does not read
            anything from it)
        """
        # set the axes over which to transform, None means all.
        self.axes = None

    def eval(self, u, **kwargs):
        """
        Apply the forward transform to ``u``.

        :param u: array handed to ``shifted_fft``
        :param kwargs: not supported; passing any keyword argument raises
        :return: ``shifted_fft(u, axes=self.axes)``
        :raises NotImplementedError: if any keyword argument is given
        """
        if kwargs:
            raise NotImplementedError("Handling of keyword arguments is not yet supported")
        return shifted_fft(u, axes=self.axes)
class InvPropagatorFFTn(PropagatorFFTn):
    """
    Basic inverse propagator which uses the inverse n-dimensional Fourier
    transform, transforming over all axes by default.

    Construction is inherited from ``PropagatorFFTn``; only ``eval`` is
    overridden so that ``shifted_ifft`` is called instead of ``shifted_fft``.
    """

    def eval(self, u, **kwargs):
        """
        Apply the inverse transform to ``u``.

        :param u: array handed to ``shifted_ifft``
        :param kwargs: not supported; passing any keyword argument raises
        :return: ``shifted_ifft(u, axes=self.axes)``
        :raises NotImplementedError: if any keyword argument is given
        """
        if kwargs:
            raise NotImplementedError("Handling of keyword arguments is not yet supported")
        return shifted_ifft(u, axes=self.axes)
class PropagatorFFT2(PropagatorFFTn):
    """
    Basic propagator which uses the two-dimensional Fourier transform.

    ``eval`` is inherited from ``PropagatorFFTn``; this class only fixes
    the transform axes to the last two.
    """

    def __init__(self, experiment):
        """
        Initialize propagator

        :param experiment: experiment class, can be used to pass options
            as attributes of the class
        """
        # Run the base initializer first, then override its default axes.
        super().__init__(experiment)
        self.axes = (-2, -1)
class InvPropagatorFFT2(PropagatorFFT2):
    """
    Basic inverse propagator which uses the inverse two-dimensional
    Fourier transform.

    The axes are inherited from ``PropagatorFFT2``; only ``eval`` is
    overridden so that ``shifted_ifft`` is called.
    """

    def eval(self, u, **kwargs):
        """
        Apply the inverse transform to ``u``.

        :param u: array handed to ``shifted_ifft``
        :param kwargs: not supported; passing any keyword argument raises
        :return: ``shifted_ifft(u, axes=self.axes)``
        :raises NotImplementedError: if any keyword argument is given
        """
        if kwargs:
            raise NotImplementedError("Handling of keyword arguments is not yet supported")
        return shifted_ifft(u, axes=self.axes)
class Propagator_FreFra_Base:
    """
    Base class for Propagator_FreFra and InvPropagator_FreFra

    Copies the settings used by those propagators from the experiment
    object onto the instance.
    """

    def __init__(self, experiment):
        """
        Store the experiment settings on the propagator.

        Parameters
        ----------
        experiment : object
            Must provide the attributes ``farfield``, ``Nx``, ``Ny`` and
            ``data_sq``. The attributes ``FT_conv_kernel``,
            ``fresnel_nr``, ``illumination``, ``magn`` and ``beam`` are
            optional; each one that is missing is stored as None.

        Raises
        ------
        AttributeError
            If one of the mandatory attributes is missing.
        """
        # Optional settings: fall back to None when the experiment lacks them.
        self.FT_conv_kernel = getattr(experiment, 'FT_conv_kernel', None)
        self.fresnel_nr = getattr(experiment, 'fresnel_nr', None)
        # Mandatory settings: plain attribute access, so a missing one fails loudly.
        self.farfield = experiment.farfield
        self.Nx = experiment.Nx
        self.Ny = experiment.Ny
        self.illumination = getattr(experiment, 'illumination', None)
        self.magn = getattr(experiment, 'magn', None)
        self.beam = getattr(experiment, 'beam', None)
        self.data_sq = experiment.data_sq
class Propagator_FreFra(Propagator_FreFra_Base):
    """
    Propagator for near field or far field Fourier measurements.

    Based on Matlab code written by Russell Luke (Inst. Fuer
    Numerische und Angewandte Mathematik, Universitaet
    Gottingen) around Jan 23, 2019.
    """

    def __init__(self, experiment):
        super(Propagator_FreFra, self).__init__(experiment)

    def eval(self, u, prox_idx=None):
        """
        Propagation function for near field or far field Fourier
        measurements.

        Parameters
        ----------
        u : array_like
            Function in the physical domain to be projected
        prox_idx : int, optional
            Index of the prox operator calling this method. It is clamped
            to the last entry of ``FT_conv_kernel`` (or of ``fresnel_nr``
            when no kernel is set) to obtain the index ``j`` used below.
            Without it, ``j`` is 0.

        Returns
        -------
        u_hat : array_like
            the propagated field at the measurement plane

        Notes
        -----
        With ``FFT``/``IFFT`` denoting the selected transform pair:

        * far field, ``fresnel_nr[j] > 0``:
          ``-1j*fresnel_nr[j]/(Nx*Ny*2*pi) * FFT(u - illumination[j]) + FT_conv_kernel[j]``
        * far field, kernel set: ``FFT(FT_conv_kernel[j]*u)/(Nx*Ny)``
        * far field, otherwise: ``FFT(u)/sqrt(Nx*Ny)``
        * near field: ``IFFT(FT_conv_kernel[j]*FFT(u*beam[j]))/magn``,
          without the ``beam[j]`` factor when no beam is set.
        """
        # Pick the index j into the per-measurement lists.
        if prox_idx is None:
            j = 0
        elif self.FT_conv_kernel is not None:
            j = min(prox_idx, len(self.FT_conv_kernel) - 1)
        elif self.fresnel_nr is not None:
            # Clamp first, then test the entry at the clamped index. The
            # previous code read fresnel_nr[j] before j was assigned and
            # raised UnboundLocalError on this path.
            # NOTE(review): falling back to 0 for a non-positive entry
            # mirrors the old final else-branch -- confirm intent.
            j = min(prox_idx, len(self.fresnel_nr) - 1)
            if not self.fresnel_nr[j] > 0:
                j = 0
        else:
            j = 0

        # 2-D transforms when both leading dimensions exceed 1, else 1-D.
        # NOTE(review): numpy's fft/ifft act on the last axis by default,
        # so an (N, 1) input is transformed along its length-1 axis.
        n_rows = u.shape[0]
        n_cols = u.shape[1] if u.ndim > 1 else 1
        if n_rows > 1 and n_cols > 1:
            FFT, IFFT = fft2, ifft2
        else:
            FFT, IFFT = fft, ifft

        if self.farfield:
            if self.fresnel_nr is not None and self.fresnel_nr[j] > 0:
                illumination_j = self.illumination[j]
                u_hat = -1j*self.fresnel_nr[j]/(self.Nx*self.Ny*2*pi)*FFT(u-illumination_j) + self.FT_conv_kernel[j]
            elif self.FT_conv_kernel is not None:
                u_hat = FFT(self.FT_conv_kernel[j]*u)/(self.Nx*self.Ny)
            else:
                u_hat = FFT(u)/sqrt(self.Nx*self.Ny)
        else:  # near field
            if self.beam is not None:
                u_hat = IFFT(self.FT_conv_kernel[j]*FFT(u*self.beam[j]))/self.magn
            else:
                u_hat = IFFT(self.FT_conv_kernel[j]*FFT(u))/self.magn
        return u_hat
class InvPropagator_FreFra(Propagator_FreFra_Base):
    """
    Inverse propagator for near field or far field Fourier measurements

    Based on Matlab code written by Russell Luke (Inst. Fuer
    Numerische und Angewandte Mathematik, Universitaet
    Gottingen) around Jan 23, 2019.
    """

    def __init__(self, experiment):
        super(InvPropagator_FreFra, self).__init__(experiment)

    def eval(self, p_Mhat, prox_idx=None):
        """
        Inverse propagation function for near field or far field Fourier
        measurements.

        Parameters
        ----------
        p_Mhat : array_like
            Function in the measurement plane to be propagated
            back to the physical plane
        prox_idx : int, optional
            Index of the prox operator calling this method. It is clamped
            to the last entry of ``data_sq`` to obtain the index ``j``
            used below. Without it, ``j`` is 0.

        Returns
        -------
        u_new : array_like
            Propagated field at the object plane

        Notes
        -----
        With ``FFT``/``IFFT`` denoting the selected transform pair:

        * far field, ``fresnel_nr[j] > 0``:
          ``(Nx*Ny*2*pi)*IFFT(p_Mhat) / fresnel_nr[j]``
        * far field, kernel set:
          ``conj(FT_conv_kernel[j])*IFFT(p_Mhat) * Nx*Ny``
        * far field, otherwise: ``IFFT(p_Mhat)*sqrt(Nx*Ny)``
        * near field: ``IFFT(conj(FT_conv_kernel[j])*FFT(p_Mhat*magn))``,
          additionally divided by ``beam[j]`` when a beam is set.
        """
        # Pick the index j into the per-measurement lists.
        j = 0 if prox_idx is None else min(prox_idx, len(self.data_sq) - 1)

        # 2-D transforms when both leading dimensions exceed 1, else 1-D.
        # NOTE(review): numpy's fft/ifft act on the last axis by default,
        # so an (N, 1) input is transformed along its length-1 axis.
        n_rows = p_Mhat.shape[0]
        n_cols = p_Mhat.shape[1] if p_Mhat.ndim > 1 else 1
        if n_rows > 1 and n_cols > 1:
            FFT, IFFT = fft2, ifft2
        else:
            FFT, IFFT = fft, ifft

        if self.farfield:
            if self.fresnel_nr is not None and self.fresnel_nr[j] > 0:
                u_new = (self.Nx*self.Ny*2*pi)*IFFT(p_Mhat) / self.fresnel_nr[j]
            elif self.FT_conv_kernel is not None:
                u_new = (conj(self.FT_conv_kernel[j])*IFFT(p_Mhat)) * self.Nx*self.Ny
            else:
                u_new = IFFT(p_Mhat)*sqrt(self.Nx*self.Ny)
        else:  # near field
            if self.beam is not None:
                u_new = IFFT(conj(self.FT_conv_kernel[j])*FFT(p_Mhat*self.magn)) / self.beam[j]
            else:
                u_new = IFFT(conj(self.FT_conv_kernel[j])*FFT(p_Mhat*self.magn))
        return u_new