Source code for proxtoolbox.Algorithms.AP
# -*- coding: utf-8 -*-
"""
Created on Mon Dec 14 13:08:06 2015
@author: rebecca
"""
from math import sqrt
from numpy import zeros
from scipy.linalg import norm
from .algorithms import Algorithm
class AP(Algorithm):
    """
    Alternating Projections

    Starting from ``proj2.work(u)``, repeatedly applies ``proj1.work`` and
    then ``proj2.work`` until the change between successive ``proj1``
    iterates drops below a tolerance or an iteration cap is reached.
    """

    def __init__(self, config):
        """
        Parameters
        ----------
        config : dict
            Dictionary containing the problem configuration.
            The following mappings are read by this constructor:

            proj1: ProxOperator
                First ProxOperator (the class, no instance); it is
                instantiated here as ``config['proj1'](config)``
            proj2: ProxOperator
                Second ProxOperator (the class, no instance); it is
                instantiated here as ``config['proj2'](config)``
            normM: number
                Scaling constant; every norm computed in :meth:`run` is
                divided by it before being squared
            Nx: int
                Row-dim of the product space elements
            Ny: int
                Column-dim of the product space elements
            Nz: int
                Depth-dim of the product space elements
            dim: int
                Size of the product space

            The following mappings belong to the common problem
            configuration but are not read by this class:

            beta0: number
                Starting relaxation parameter
            beta_max: number
                Maximum relaxation parameter
            beta_switch: int
                Iteration at which beta moves from beta0 -> beta_max
        """
        # Both operators are built from the same configuration dictionary.
        self.proj1 = config['proj1'](config)
        self.proj2 = config['proj2'](config)
        self.normM = config['normM']
        self.Nx = config['Nx']
        self.Ny = config['Ny']
        self.Nz = config['Nz']
        self.dim = config['dim']
        self.iters = 0

    def _scaled_sq_dist(self, a, b):
        """
        Return the sum of ``(norm(a_s - b_s, 'fro') / normM)**2`` over the
        slices ``s`` selected by the problem dimensions.

        * ``Ny == 1`` or ``Nx == 1``: a single norm over the whole arrays.
        * ``Nz == 1``: one norm per ``[:, :, j]`` slice, ``j < dim``.
        * otherwise: one norm per ``[:, :, k, j]`` slice, ``k < Nz``,
          ``j < dim``.
        """
        normM = self.normM
        if self.Ny == 1 or self.Nx == 1:
            return (norm(a - b, 'fro') / normM) ** 2
        if self.Nz == 1:
            return sum(
                (norm(a[:, :, j] - b[:, :, j], 'fro') / normM) ** 2
                for j in range(self.dim)
            )
        return sum(
            (norm(a[:, :, k, j] - b[:, :, k, j], 'fro') / normM) ** 2
            for j in range(self.dim)
            for k in range(self.Nz)
        )

    def run(self, u, tol, maxiter):
        """
        Runs the algorithm for the specified input data

        Parameters
        ----------
        u : array_like
            Starting point. It is indexed as ``u[:, :, j]`` when
            ``Nz == 1`` and as ``u[:, :, k, j]`` in the general case
            (presumably 2-D when ``Nx == 1`` or ``Ny == 1`` -- confirm
            against the callers).
        tol : number
            The iteration continues while the latest change is >= tol.
        maxiter : int
            Maximum number of iterations.

        Returns
        -------
        u1 : array_like
            ``proj1.work`` applied to the final iterate (sliced, see the
            note in the code below).
        u2 : array_like
            ``proj2.work`` applied to the final iterate (sliced likewise).
        iters : int
            Value of the iteration counter when the loop stopped.
        change : ndarray
            Per-iteration scaled distance between consecutive iterates.
        gap : ndarray
            Per-iteration scaled distance between the ``proj2`` and
            ``proj1`` images of the current iterate.
        """
        proj1 = self.proj1
        proj2 = self.proj2
        iters = self.iters

        # Slot 0 holds a large sentinel so the loop is entered at least
        # once for any reasonable tolerance; it is stripped before return.
        change = zeros(maxiter + 1)
        change[0] = 999
        gap = change.copy()

        tmp1 = proj2.work(u)

        while iters < maxiter and change[iters] >= tol:
            iters += 1
            tmp_u = proj1.work(tmp1)
            tmp1 = proj2.work(tmp_u)

            # The gap for Nz == 1 used to be written as
            # (norm(...)/normM, 'fro')**2, i.e. a tuple raised to a power,
            # which raises TypeError. Both quantities now share one helper.
            change[iters] = sqrt(self._scaled_sq_dist(u, tmp_u))
            gap[iters] = sqrt(self._scaled_sq_dist(tmp1, tmp_u))

            u = tmp_u.copy()

        tmp = proj1.work(u)
        tmp2 = proj2.work(u)

        # NOTE(review): the slices below use index 1, which is the SECOND
        # entry along that axis in NumPy. This looks like a 1-based index
        # carried over from another language (index 0 may be intended) --
        # confirm against the expected array layout before changing it.
        if self.Ny == 1:
            u1 = tmp[:, 1]
            u2 = tmp2[:, 1]
        elif self.Nx == 1:
            u1 = tmp[1, :]
            u2 = tmp2[1, :]
        elif self.Nz == 1:
            u1 = tmp[:, :, 1]
            u2 = tmp2[:, :, 1]
        else:
            u1 = tmp
            u2 = tmp2

        # Drop the sentinel slot and the unused tail.
        change = change[1:iters + 1]
        gap = gap[1:iters + 1]
        return u1, u2, iters, change, gap