from numpy import roll, ndarray, floor, iscomplexobj, round, any, isnan, nan_to_num
from scipy.ndimage.measurements import maximum_position, center_of_mass
from scipy.fftpack import fftn, fftshift, ifftn, ifftshift
from warnings import warn
from numpy.lib.stride_tricks import as_strided
# public API of this module
__all__ = [
    'shift_array',
    'roll_to_pos',
    'shifted_ifft',
    'shifted_fft',
    'tile_array',
]
def shift_array(arr: ndarray, dy: int, dx: int):
    """
    Cyclically shift an array along its first two axes (thin wrapper around numpy.roll).

    Elements pushed past the end of an axis wrap around to its start.

    :param arr: numpy array with at least two dimensions
    :param dy: number of positions to roll along axis 0
    :param dx: number of positions to roll along axis 1
    :return: shifted array with the same shape as arr
    """
    return roll(arr, shift=(dy, dx), axis=(0, 1))
def roll_to_pos(arr: ndarray, y: int = 0, x: int = 0, pos: tuple = None, move_maximum: bool = False,
                by_abs_val: bool = True) -> ndarray:
    """
    Shift the center of mass (or the maximum) of an array to the given position by cyclic permutation.

    The current position is floored to integer indices and the shift towards the target is rounded, so a
    target gives the same result whether it is passed as y, x or as pos.

    :param arr: array with at least len(pos) dimensions (at least 2 when pos is None); works best for a
        well-centered feature with limited support
    :param y: target index along the first axis (ignored when pos is given)
    :param x: target index along the second axis (ignored when pos is given)
    :param pos: tuple with the target index for each leading axis, overriding y, x; use this for
        higher-dimensional arrays
    :param move_maximum: if True, move the position of the maximum value instead of the center of mass
    :param by_abs_val: locate the maximum / center of mass on abs(arr); complex arrays always use abs(arr)
    :return: array with the same shape as arr
    """
    # complex values cannot be ordered or used as weights, so their magnitude is used regardless of by_abs_val
    weights = abs(arr) if (by_abs_val or iscomplexobj(arr)) else arr
    locate = maximum_position if move_maximum else center_of_mass
    old = floor(locate(weights))
    if any(isnan(old)):
        # center_of_mass can yield NaN (e.g. when the weights sum to zero); nan_to_num maps those to 0
        old = nan_to_num(old)
        warn(Warning("Unexpected error in the calculation of the center of mass, casting NaNs to num"))
    if pos is None:
        pos = (y, x)
    axes = tuple(range(len(pos)))
    # round (instead of truncating with int()) so float targets behave identically in both call styles
    shifts = tuple(int(round(pos[ax] - old[ax])) for ax in axes)
    # numpy.roll never changes the shape, so no output-shape check is needed
    return roll(arr, shift=shifts, axis=axes)
def shifted_fft(arr, axes=None):
    """
    Centered n-dimensional FFT based on scipy.fftpack.

    The input origin is expected at index n//2 along every transformed axis, and the zero-frequency
    component of the output is placed at index n//2 as well.

    Args:
        arr: numpy array
        axes: axes to transform; identical to the argument of scipy.fftpack.fftn (None = all axes)
    Returns:
        transformed array
    """
    # ifftshift moves the centered origin to index 0 before the transform, fftshift moves the zero
    # frequency back to the center afterwards. fftshift and ifftshift differ for odd lengths, so the order
    # matters; this is the mirror image of shifted_ifft and therefore its exact inverse.
    return fftshift(fftn(ifftshift(arr, axes=axes), axes=axes), axes=axes)
def shifted_ifft(arr, axes=None):
    """
    Centered n-dimensional inverse FFT based on scipy.fftpack.

    The zero-frequency component of the input is expected at index n//2 along every transformed axis,
    and the origin of the output is placed at index n//2 as well.

    Args:
        arr: numpy array
        axes: axes to transform; identical to the argument of scipy.fftpack.ifftn (None = all axes)
    Returns:
        transformed array
    """
    # move the centered zero frequency to index 0, where ifftn expects it
    uncentered = ifftshift(arr, axes=axes)
    spatial = ifftn(uncentered, axes=axes)
    # put the origin of the result back at the center
    return fftshift(spatial, axes=axes)
def tile_array(a: ndarray, shape, normalize: bool = False):
    """
    Upsample an array by nearest-neighbour interpolation, i.e. [1,2] -> [1,1,2,2]

    Every element is repeated into a tile of the given size along each axis. Works for arrays of any
    number of dimensions (previously only 2 and 3).

    :param a: numpy array
    :param shape: tile size, single integer for the same factor on every axis, or a sequence with one
        factor per axis of a
    :param normalize: if True, divide by the number of elements per tile, so the output sums to the same
        total as the input
    :return: resampled array; because of the final true division, integer input is returned as float
        even when normalize is False
    :raises ValueError: if shape is a sequence whose length differs from a.ndim
    """
    try:
        factors = tuple(shape)
    except TypeError:  # scalar tile size -> same factor on every axis
        factors = (shape,) * a.ndim
    if len(factors) != a.ndim:
        raise ValueError(f"tile shape {factors} does not match array dimension {a.ndim}")

    # View a as a (n0, b0, n1, b1, ...) array: each original axis is followed by a zero-stride tile axis,
    # so every element appears b times along it without copying.
    view_shape = []
    view_strides = []
    for n, stride, b in zip(a.shape, a.strides, factors):
        view_shape += [n, b]
        view_strides += [stride, 0]
    # read-only: elements of this view alias each other, writing through it would be unsafe
    tiled = as_strided(a, tuple(view_shape), tuple(view_strides), writeable=False)

    norm = 1
    if normalize:
        for b in factors:
            norm *= b
    out_shape = tuple(n * b for n, b in zip(a.shape, factors))
    # reshape merges each (n, b) pair into one axis of length n * b (copying, since the strides are 0)
    return tiled.reshape(out_shape) / norm