import numpy as np
from numpy import conj, dot, empty, ones, sqrt, sum, zeros, exp, \
nonzero, log, tile, shape, real, zeros_like, pi
# from pyfftw.interfaces.scipy_fftpack import fft2, ifft2
from numpy.fft import fft2, ifft2, fft, ifft , fftshift, ifftshift
# TODO: switch to scipy.fftpack? (is faster supposedly)
class ProxOperator:
    """
    Generic interface for prox operators.

    This base class carries no state and implements nothing; subclasses
    are expected to override :meth:`eval`, which raises
    ``NotImplementedError`` here.
    """

    def __init__(self, experiment):
        """
        Initialization method for a concrete instance.

        Parameters
        ----------
        experiment : instance of Experiment class
            Experiment object that will use this prox operator.
            The base class ignores it and stores nothing.
        """

    def eval(self, u, prox_idx=None):
        """
        Apply the prox operator to some input data.

        Parameters
        ----------
        u : ndarray or a list of ndarray objects
            Input data to be projected.
        prox_idx : int, optional
            Index of this prox operator.

        Returns
        -------
        ndarray or a list of ndarray objects
            Result of applying the prox operator to the input data.

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        message = "This is just an abstract interface"
        raise NotImplementedError(message)
# @profile
def magproj(constr, u, eps=3e-20):
    """
    Projection operator onto a magnitude constraint.

    Inexact, but stable implementation of magnitude projection.
    See Luke, Burke and Lyon, SIAM Review (SIREV), 2002.
    Based on Matlab code written by Russell Luke (Inst. Fuer
    Numerische und Angewandte Mathematik, Universitaet
    Gottingen) on July 24, 2001.

    As ``eps -> 0`` the result tends to ``constr * u / |u|``. The small
    positive ``eps`` regularizes the division by ``|u|``, so entries with
    ``u == 0`` are mapped to 0 (for finite ``constr``) instead of
    producing NaN or inf.

    Parameters
    ----------
    constr : array_like
        A nonnegative array that is the magnitude constraint; must be
        broadcast-compatible with ``u``.
    u : array_like
        The function to be projected onto constr (can be complex).
    eps : float, optional
        Strictly positive regularization parameter. Defaults to 3e-20,
        the value previously hard-coded here.

    Returns
    -------
    ndarray
        The projection (a NumPy scalar for scalar input).

    Raises
    ------
    ValueError
        If ``eps`` is not strictly positive.
    """
    if not eps > 0:
        raise ValueError("eps must be strictly positive")
    # Honor the documented array_like contract (lists, scalars); asanyarray
    # does not copy ndarrays and keeps subclasses such as masked arrays.
    u = np.asanyarray(u)
    # |u|^2 from the real/imag parts: computing u * conj(u) instead yields a
    # complex array and makes the following operations roughly twice as slow.
    modsq_u = u.real ** 2 + u.imag ** 2
    denom = modsq_u + eps
    denom2 = np.sqrt(denom)
    # Regularized residual |u|_eps - constr, with |u|_eps = |u|^2 / sqrt(|u|^2 + eps).
    r_eps = (modsq_u / denom2) - constr
    # dr_eps -> 1/|u| as eps -> 0, hence (1 - dr_eps * r_eps) -> constr / |u|.
    dr_eps = (denom + eps) / (denom * denom2)
    return (1 - (dr_eps * r_eps)) * u