{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Constrained Tikhonov regularization implemented by (accelerated) forward-backward splitting\n", "We consider the two-dimensional deconvolution problem of finding a non-negative function $f$ given data \n", "$$\n", " d \\sim \\mathrm{Pois}(h*f)\n", "$$\n", "with a non-negative convolution kernel $h$, where $\\mathrm{Pois}$ denotes the element-wise Poisson distribution.\n", "\n", "We explore the use of forward-backward splitting and its accelerated variant FISTA to implement constrained Tikhonov regularization \n", "$$\n", "\\hat{f} = \\mathrm{argmin}_{f\\geq 0} \\left[\\| h*f-d\\|^2_{L^2(w)} + \\alpha \\|f\\|^2_{L^2}\\right]\n", "$$\n", "with a weight $w = \\frac{1}{\\sqrt{d+1}}$. The regularization parameter $\\alpha$ is fixed, and the iteration is stopped once the duality gap is small." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from regpy.operators.convolution import GaussianBlur\n", "from regpy.vecsps import UniformGridFcts\n", "from regpy.solvers import TikhonovRegularizationSetting\n", "from regpy.solvers.linear.proximal_gradient import ForwardBackwardSplitting, FISTA\n", "from regpy.hilbert import L2, HmDomain\n", "from regpy.stoprules import CountIterations, DualityGapStopping" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Logging\n", "\n", "Each subroutine logs at level INFO to report its progress, and we predefine a common format for the log messages." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import logging\n", "\n", "logging.basicConfig(\n", "    level=logging.INFO,\n", "    format='%(asctime)s %(levelname)s %(name)-20s :: %(message)s'\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting routine \n", "\n", "To plot results easily and consistently later on, we define a routine that displays the images using `matplotlib.pyplot`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import matplotlib as mplib\n", "\n", "def comparison_plot(grid,truth,reco,title_right='exact',title_left='reconstruction',residual=None):\n", "    plt.rcParams.update({'font.size': 22})\n", "    extent = [grid.axes[0][0],grid.axes[0][-1], grid.axes[1][0], grid.axes[1][-1]]\n", "    maxval = np.max(truth[:]); minval = np.min(truth[:])\n", "    mycmap = mplib.colormaps['hot']\n", "    mycmap.set_over((0,0,1.,1.)) # blue marks values above the color range\n", "    mycmap.set_under((0,1,0,1.)) # green marks values below the color range\n", "    if not (residual is None):\n", "        fig, ((ax1,ax2),(ax3,ax4)) = plt.subplots(2,2,figsize = (22,16))\n", "    else:\n", "        fig, (ax1,ax2) = plt.subplots(1,2,figsize = (22,8))\n", "    im1= ax1.imshow(reco.T,extent=extent,origin='lower',\n", "                    vmin=1.05*minval-0.05*maxval, vmax =1.05*maxval-0.05*minval,\n", "                    cmap=mycmap\n", "                    )\n", "    ax1.title.set_text(title_left)\n", "    fig.colorbar(im1,extend='both')\n", "    im2= ax2.imshow(truth.T,extent=extent, origin='lower',\n", "                    vmin=1.05*minval-0.05*maxval, vmax =1.05*maxval-0.05*minval,\n", "                    cmap=mycmap\n", "                    )\n", "    ax2.title.set_text(title_right)\n", "    fig.colorbar(im2,extend='both',orientation='vertical')\n", "    if not (residual is None):\n", "        # symmetric color range around zero for the diverging colormap\n", "        maxv = np.max(np.abs(reco[:]-truth[:]))\n", "        im3 = ax3.imshow(truth.T-reco.T,extent=extent, origin='lower',vmin= -maxv,vmax=maxv,cmap='RdYlBu_r')\n", "        ax3.title.set_text('reconstruction error')\n",
"        fig.colorbar(im3)\n", "\n", "        maxv = np.max(np.abs(residual[:]))\n", "        im4 = ax4.imshow(residual.T,extent=extent, origin='lower',vmin= -maxv,vmax=maxv,cmap='RdYlBu_r')\n", "        ax4.title.set_text('data residual')\n", "        fig.colorbar(im4)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Test object\n", "\n", "The exact solution is composed of a ring, a box, and a cross in the upper part and some decreasing bubbles at the bottom." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "grid = UniformGridFcts((-1, 1, 256), (-1.5, 1, 256),dtype = float, periodic = True)\n", "\"\"\"Space of real-valued functions on a uniform grid with rectangular pixels\"\"\"\n", "X = grid.coords[0]; Y = grid.coords[1]\n", "\"\"\"x and y coordinates.\"\"\"\n", "cross = 1.0*np.logical_or((abs(X)<0.01) * (abs(Y)<0.3),(abs(X)<0.3) * (abs(Y)<0.01)) \n", "rad = np.sqrt(X**2 + Y**2)\n", "ring = 1.0*np.logical_and(rad>=0.9, rad<=0.95)\n", "smallbox = (abs(X+0.55)<=0.05) * (abs(Y-0.55)<=0.05)\n", "bubbles = (1.001+np.sin(50/(X+1.3)))*np.exp(-((Y+1.25)/0.1)**2)*(X>-0.8)*(X<0.8)\n", "\n", "ramp = Y<=-1\n", "\n", "objects = 200*(ring + 2.0*cross + 1.5*smallbox + 2*ramp -bubbles)\n", "exact_sol = objects \n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Creating synthetic data\n", "\n", "Using `np.random.poisson`, we construct Poisson data from a Gaussian blur of the true solution. The Gaussian blur is implemented by the `GaussianBlur` convolution operator of `RegPy`." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "a=0.15\n", "conv = GaussianBlur(grid,a,pad_amount=((16,16),(16,16)))\n", "\"\"\"Convolution operator \\(f\\mapsto h*f\\) for the convolution kernel \\(h(x)=\\exp(-|x|_2^2/a^2)\\).\"\"\"\n", "blur = conv(exact_sol)\n", "blur[blur<0] = 0.\n", "\"\"\"Simulated exact data.\"\"\"\n", "data = np.random.poisson(blur)\n", "\"\"\"Simulated measured data. 
The Poisson distribution arises when photon-counting detectors are used.\"\"\"\n", "comparison_plot(grid,exact_sol,data,title_left='noisy measurement data')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reconstruction by Forward-Backward Splitting \n", "We split the Tikhonov functional into the quadratic data fidelity term and a penalty term, which also includes the nonnegativity constraint. Then we set up solvers based on forward-backward splitting and FISTA to minimize the Tikhonov functional.\n", "\n", "Note that the quadratic bilateral constraint tolerates small violations of the bounds, so that evaluating the functional does not return infinity. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from regpy.functionals import QuadraticBilateralConstraints\n", "\n", "weighted_data_space = L2(grid, weights = 1./(1.+data))\n", "pen = QuadraticBilateralConstraints(grid,lb=0,ub=400*conv.domain.ones(),x0=0,eps=0.0001)\n", "alpha = 0.001\n", "setting = TikhonovRegularizationSetting(op=conv, penalty=pen, data_fid = weighted_data_space,data_fid_shift=data,regpar=alpha)\n", "\n", "solver_FB = ForwardBackwardSplitting(setting)\n", "solver_FISTA = FISTA(setting)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Solving the inverse problem\n", "\n", "We run both solvers with a stopping rule based on the duality gap." 
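, "\n", "\n", "As a rough sketch of what the two solvers iterate (the notation $Tf = h*f$, the step size $\\tau$ and the scaling of the functional are assumptions made for illustration and need not match RegPy's internal conventions): up to a common factor $\\frac12$, the functional reads $\\frac12\\|Tf-d\\|^2_{L^2(w)} + \\alpha\\mathcal{R}(f)$, and forward-backward splitting alternates a gradient step on the data fidelity term with a proximal step on the penalty,\n", "$$\n", "f_{k+1} = \\mathrm{prox}_{\\tau\\alpha\\mathcal{R}}\\left(f_k - \\tau T^*(Tf_k-d)\\right), \\qquad 0<\\tau\\leq \\|T\\|^{-2},\n", "$$\n", "where $T^*$ denotes the adjoint with respect to the weighted data space. For $\\mathcal{R}(f) = \\frac12\\|f\\|^2_{L^2}$ restricted to $f\\geq 0$, the proximal map is explicit and acts pointwise:\n", "$$\n", "\\mathrm{prox}_{\\tau\\alpha\\mathcal{R}}(g) = \\max\\left(\\frac{g}{1+\\tau\\alpha},0\\right).\n", "$$\n", "FISTA applies the same step at the extrapolated point $g_k = f_k + \\frac{t_k-1}{t_{k+1}}(f_k-f_{k-1})$ with $t_1=1$ and $t_{k+1} = \\frac12\\left(1+\\sqrt{1+4t_k^2}\\right)$, which improves the worst-case decay of the objective from $O(1/k)$ to $O(1/k^2)$."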
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from regpy.stoprules import DualityGapStopping\n", "\n", "stoprule=DualityGapStopping(solver_FB,threshold = 1., max_iter=1500,logging_level=logging.WARNING)\n", "x,y = solver_FB.run(stoprule)\n", "comparison_plot(grid,exact_sol,x)\n", "stoprule=DualityGapStopping(solver_FISTA,threshold = 1., max_iter=150,logging_level=logging.WARNING)\n", "x,y = solver_FISTA.run(stoprule)\n", "comparison_plot(grid,exact_sol,x)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reconstruction with one-sided constraints\n", "\n", "Instead of a bilateral constraint, we may impose only a lower bound with a quadratic penalty. To do so, we redefine the setting, create new solvers, and run the reconstruction again. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from regpy.functionals import QuadraticLowerBound\n", "weighted_data_space = L2(grid, weights = 1./(1.+data))\n", "alpha = 1e-3\n", "pen = QuadraticLowerBound(grid,grid.zeros(), 50*grid.ones())\n", "setting = TikhonovRegularizationSetting(op=conv, penalty=pen, data_fid = L2,data_fid_shift=data,regpar=alpha)\n", "solver_FB = ForwardBackwardSplitting(setting,grid.zeros())\n", "solver_FISTA = FISTA(setting)\n", "\n", "stoprule=DualityGapStopping(solver_FB,threshold = 1., max_iter=2500,logging_level=logging.WARNING)\n", "x,y = solver_FB.run(stoprule)\n", "comparison_plot(grid,exact_sol,x)\n", "stoprule=DualityGapStopping(solver_FISTA,threshold = 1., max_iter=200,logging_level=logging.WARNING)\n", "x,y = solver_FISTA.run(stoprule)\n", "comparison_plot(grid,exact_sol,x)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reconstruction without constraints\n", "\n", "For comparison with the previous reconstructions, we use a standard $L^2$ penalty and enforce non-negativity only after the reconstruction by clipping negative values to zero." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "weighted_data_space = L2(grid, weights = 1./(1.+data))\n", "alpha = 1e-3\n", "pen = L2\n", "setting = TikhonovRegularizationSetting(op=conv, penalty=pen, data_fid = L2,data_fid_shift=data,regpar=alpha)\n", "setting.penalty\n", "solver_FB = ForwardBackwardSplitting(setting,grid.zeros())\n", "solver_FISTA = FISTA(setting)\n", "\n", "stoprule=DualityGapStopping(solver_FB,threshold = 1., max_iter=2500,logging_level=logging.WARNING)\n", "x,y = solver_FB.run(stoprule)\n", "x=np.maximum(x,0)\n", "comparison_plot(grid,exact_sol,x)\n", "stoprule=DualityGapStopping(solver_FISTA,threshold = 1., max_iter=200,logging_level=logging.WARNING)\n", "x,y = solver_FISTA.run(stoprule)\n", "x=np.maximum(x,0)\n", "comparison_plot(grid,exact_sol,x)" ] } ], "metadata": { "kernelspec": { "display_name": "3.12.3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 2 }