Phase Retrieval II: Iterative Transform Algorithms

Introduction

In a previous post I motivated phase retrieval for wavefront sensing and gave a cursory comparison between it and more typical wavefront measurement schemes. In this post, I will elaborate upon the “Iterative Transform” family of phase retrieval algorithms, focusing on the ones most relevant to image-based wavefront sensing. These algorithms come in two varieties: two-image and single-image. The two-image algorithms use a measured PSF and a measured pupil-plane image. The single-image algorithms use only a measured PSF and impose general constraints on the pupil plane, for example that the pupil have a certain support.

I will use the same notation here as Fienup’s famous 1982 paper, which is the most cited in Applied Optics to this day.

Iterative transform algorithms are all structured in the same way:

  1. Compile a set of constraints in the “object” (pupil) domain and in the “Fourier” (PSF) domain
  2. Propagate the object to the Fourier domain
  3. Apply Fourier constraints
  4. Back-propagate to the object domain, apply object domain constraints, then return to step 2 until a termination criterion is met
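As a minimal sketch, the steps can be written as a generic loop that takes the two sets of constraints as functions. This is our own illustration, not PRAISE's API. The names `iterative_transform`, `fourier_magnitude` and `object_magnitude` are hypothetical. The synthetic circular pupil and random phase exist only to exercise the loop. The object constraint receives both $g'$ and the previous $g$ so that updates which depend on $g$ can be expressed the same way.

```python
import numpy as np


def iterative_transform(g0, fourier_constraint, object_constraint, n_iter=100):
    """Generic iterative transform loop (illustrative sketch, not PRAISE's API).

    fourier_constraint(G) returns G'; object_constraint(g_prime, g) returns g''.
    Both are caller-supplied; this function only wires steps 2-4 together.
    """
    g = g0
    for _ in range(n_iter):
        G = np.fft.fft2(g)                 # step 2: object -> Fourier domain
        G_prime = fourier_constraint(G)    # step 3: Fourier constraints
        g_prime = np.fft.ifft2(G_prime)    # step 4: back to the object domain...
        g = object_constraint(g_prime, g)  # ...and object constraints
    return g


# Example: magnitude constraints in both domains on synthetic data.
rng = np.random.default_rng(0)
n = 32
yy, xx = np.mgrid[:n, :n] - n / 2
absf = (np.hypot(xx, yy) < n / 4).astype(float)             # circular pupil
true_phase = 0.5 * rng.standard_normal((n, n))
absF = np.abs(np.fft.fft2(absf * np.exp(1j * true_phase)))  # "measured" |F|


def fourier_magnitude(G):
    return absF * np.exp(1j * np.angle(G))


def object_magnitude(g_prime, g):
    return absf * np.exp(1j * np.angle(g_prime))


g0 = absf.astype(complex)  # zero-phase start
g = iterative_transform(g0, fourier_magnitude, object_magnitude, n_iter=50)
```

Most of the algorithms in this post differ mainly in the constraint functions and in how the object-domain update is formed. The two magnitude constraints in the example correspond to the Gerchberg-Saxton steps described in the next section.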

From an optimization perspective, the Fourier constraints are often “exact,” while the object constraints may be “inequality” or otherwise weak constraints. The weakest form of object domain constraint comes from “single plane” phase retrieval algorithms. These impart only cursory knowledge of the object domain, in the form of a support constraint.

Fienup uses the following variables, which we will mimic:

| | Object | Fourier |
|---|---|---|
| Truth | $f$ | $F$ |
| Estimated | $g$ | $G$ |

The key to this table is that $g$ is an estimate of $f$, and capitalized variables are the Fourier transforms of their lowercase counterparts. This convention comes from image reconstruction, where $g$ would colloquially be an image. In that field, $f$ is usually assumed real and positive. For image-based wavefront sensing, all quantities are complex and the phase of $f$ is the desired measurement.

In this post, we will code and demonstrate several iterative transform algorithms. Each is implemented in PRAISE and copied here so the reader can see how simple they are. PRAISE itself is a teaching repository. It does not implement many tools needed for modern, high-quality, robust image-based wavefront sensing. Examples are algorithmic differentiation (for computational speed) and incorporation of diversity (necessary to remove sign ambiguities and, in general, to enhance the uniqueness of the solution).

Gerchberg-Saxton, Usage of Iterative Transform Algorithms

Arguably the most common iterative transform algorithm is the Gerchberg-Saxton algorithm, which is implemented as follows:

  1. Transform $g \rightarrow G$
  2. Compute $G' = |F| \exp( i\Theta[G])$
  3. Transform $G' \rightarrow g'$
  4. Compute $g'' = |f| \exp( i\Theta[g'])$

This procedure is iterated (feeding $g'' \rightarrow g$) until termination. Termination criteria can be formed in any number of ways, which we will elaborate upon later in this article. $\Theta[x]$ is the angle function, which returns the phase of a complex number. The quantities $|F|$ and $|f|$ are usually measured, and you can think of them as:

$$ \begin{align} |F| &\equiv \sqrt{\text{measured PSF}} \\
|f| &\equiv \sqrt{\text{measured pupil-plane image}} \end{align} $$

The square roots appear because we measure intensities, while the quantities $F$ and $f$ are complex E-fields.

A reasonable question is how to initialize $g$. If there is a unique solution (and the algorithm is allowed to iterate long enough to find it), the initialization does not matter. In practice this holds only weakly, so the choice is not of grave consequence in most cases. If a reasonable guess for $\Theta[f]$ is not known, it is relatively conventional to begin from a zero-phase $g$ or a random-phase $g$.
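A small sketch of these initialization choices follows. `initial_guess` is a hypothetical helper, not part of PRAISE. PRAISE's `_init_iterative_transform` draws its random phase with `np.random.rand`, which is uniform in [0, 1) radians. The full $[-\pi, \pi)$ range used here is an assumption, not the library's behavior. Every option yields a $g_0$ that already satisfies the pupil amplitude constraint.

```python
import numpy as np


def initial_guess(pupil_amplitude, kind="zero", phase=None, seed=None):
    """Form a starting estimate g0 = |f| * exp(i * phi0) (illustrative helper).

    kind="zero":   phi0 = 0 everywhere
    kind="random": phi0 uniform in [-pi, pi)  (assumed range)
    kind="known":  phi0 = phase, a caller-supplied prior estimate of Theta[f]
    """
    amp = np.asarray(pupil_amplitude, dtype=float)
    if kind == "zero":
        phi0 = np.zeros_like(amp)
    elif kind == "random":
        rng = np.random.default_rng(seed)
        phi0 = rng.uniform(-np.pi, np.pi, size=amp.shape)
    elif kind == "known":
        if phase is None:
            raise ValueError("kind='known' requires a phase array")
        phi0 = np.asarray(phase, dtype=float)
    else:
        raise ValueError(f"unknown kind: {kind!r}")
    return amp * np.exp(1j * phi0)


absg = np.ones((8, 8))
g_zero = initial_guess(absg, "zero")
g_rand = initial_guess(absg, "random", seed=1)
```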

An implementation of the Gerchberg-Saxton algorithm in Python is as simple as:

import numpy as np
from numpy import fft  # numpy's FFT module: fft2, ifft2, ifftshift


def _init_iterative_transform(self, psf, pupil_amplitude, phase_guess=None):
    # python teaching moment -- self as an argument name is only a convention,
    # and carries no special meaning.  This function exists to refactor the
    # various iterative transform types without subclassing or inheritance
    if phase_guess is None:
        phase_guess = np.random.rand(*pupil_amplitude.shape)

    absF = np.sqrt(psf)
    absg = pupil_amplitude

    self.absF = fft.ifftshift(absF)
    self.absg = fft.ifftshift(absg)
    phase_guess = fft.ifftshift(phase_guess)

    self.g = self.absg * np.exp(1j*phase_guess)
    self.mse_denom = np.sum((self.absF)**2)
    self.iter = 0
    self.costF = []


def _mean_square_error(a, b, norm=1):
    diff = a - b
    mse = np.sum(diff**2)
    return mse / norm


class GerchbergSaxton:

    def __init__(self, psf, pupil_amplitude, phase_guess=None):
        _init_iterative_transform(self, psf, pupil_amplitude, phase_guess)

    def step(self):
        """Advance the algorithm one iteration."""
        G = fft.fft2(self.g)
        mse = _mean_square_error(abs(G), self.absF, self.mse_denom)

        phs_G = np.angle(G)
        Gprime = self.absF * np.exp(1j*phs_G)
        gprime = fft.ifft2(Gprime)
        phs_gprime = np.angle(gprime)
        gprimeprime = self.absg * np.exp(1j*phs_gprime)

        self.costF.append(mse)
        self.iter += 1
        self.g = gprimeprime
        return gprimeprime

The steps 1..4 above are implemented in the step method. The usage of this algorithm is, loosely, as follows:

import numpy as np
import praise

psf = ... # load measured data
absg = ... # load measured pupil data, or model of transmission
phs_guess = np.zeros_like(absg)
gsp = praise.GerchbergSaxton(psf, absg, phs_guess)
for i in range(1000):
    g = gsp.step()

If you read the implementation carefully, you may notice the mse and costF bookkeeping. What is that about? Nothing in steps 1..4 tells us when the algorithm is “done.” We can only hope that the 1,000 iterations shown in this example are enough. mse is short for mean square error. It is a maximum likelihood measure between $|G|$ and $|F|$ if there is additive Gaussian noise. The code here is written with the inputs pre-shifted, so that we avoid repeatedly calling fftshift and ifftshift inside the hot loop. You could write the code with a more complex termination criterion, for example:

gsp = praise.GerchbergSaxton(psf, absg, phs_guess)
g = gsp.step()  # one step first, so that costF is not empty
while gsp.iter < 10_000 and gsp.costF[-1] > 0.001:
    g = gsp.step()

This terminates after at most 10,000 iterations, or once the cost reaches 1/1000, whichever comes first. A still more advanced way to write the usage would be:

gsp = praise.GerchbergSaxton(psf, absg, phs_guess)
MAX_ITER = 10_000
MIN_COST = 1e-3
MIN_COST_DELTA = 1e-4
last_cost = np.inf
while True:
    g = gsp.step()
    cost = gsp.costF[-1]
    delta_cost = last_cost - cost  # improvement; positive when cost falls
    if gsp.iter >= MAX_ITER:
        msg = 'terminated at maximum iteration count'
        break
    if cost <= MIN_COST:
        msg = 'terminated at sufficiently small cost'
        break
    if delta_cost <= MIN_COST_DELTA:
        msg = 'terminated due to stagnated cost'
        break

    last_cost = cost

While this is a large block for a blog post, it contains very little logic and could easily be wrapped in a helper function. It allows the algorithm to stop, as soon as possible, for any of several reasonable reasons. Fast termination has great utility if the algorithm is being used in real time or quasi real time, since more sensing (and control) iterations can be done per unit time. The process can then conclude faster (time is money), or faster dynamics in the system can be rejected (higher ultimate performance).
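Such a helper can be sketched as follows. It assumes the algorithm object exposes `step()`, an `iter` counter, and a `costF` list that each step appends to, as the PRAISE classes above do. `run_until_done` and its keyword names are hypothetical.

```python
import math


def run_until_done(alg, max_iter=10_000, min_cost=1e-3, min_improvement=1e-4):
    """Step an iterative transform object until a termination test fires.

    Assumes alg has step(), an iter counter, and a costF list that step()
    appends to. Returns (g, reason).
    """
    last_cost = math.inf
    while True:
        g = alg.step()
        cost = alg.costF[-1]
        improvement = last_cost - cost  # positive when the cost went down
        if alg.iter >= max_iter:
            return g, "terminated at maximum iteration count"
        if cost <= min_cost:
            return g, "terminated at sufficiently small cost"
        if improvement <= min_improvement:
            return g, "terminated due to stagnated cost"
        last_cost = cost
```

For example, `g, reason = run_until_done(gsp)` would replace the whole block above. Measuring stagnation as `last_cost - cost` means an uphill step also counts as “stagnated.” Whether that should end the run or instead trigger a gain reduction is a design choice.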

Returning to the GS algorithm in particular, note that the only ‘forcing’ function is the Fourier domain constraint: Gprime = self.absF * np.exp(1j*phs_G). One interpretation is that this line of code ‘pushes’ the algorithm forward on each step, while the object domain constraint acts as a more typical constraint. Assume $g_0$ satisfies the object domain constraint. It does by construction in _init_iterative_transform, which multiplies the pupil amplitude by a unit-modulus phase term. Then the object domain constraint only kicks in if the Fourier domain step pushed the estimate to a solution that violates it.

It is not self-evident that this is the optimum “force,” and the next few algorithms show that it is not. I will make special note that this algorithm was not created for image-based wavefront sensing, and merely “happens to be” appropriate for this task. Taking Gerchberg-Saxton as the original, Error Reduction is the second iterative transform algorithm.

Error Reduction

Error Reduction is the same as Gerchberg-Saxton for steps 1..3, differing at step 4. In Error Reduction, the “minimum update” is made. For image-based wavefront sensing, this simply means enforcing that $g$ be positive within the support and zero outside the support. This permits $g$ to evolve away from the measured value, if a measured value is present (the two-image version of image-based wavefront sensing). This is not particularly interesting for the two-image problem, but is very useful when the pupil amplitude is not known. In such a circumstance, Error Reduction can be used to estimate amplitude and phase simultaneously.

class ErrorReduction:

    def __init__(self, psf, pupil_amplitude, phase_guess=None):
        _init_iterative_transform(self, psf, pupil_amplitude, phase_guess)
        self.mask = self.absg > 1e-6
        self.invmask = ~self.mask

    def step(self):
        G = fft.fft2(self.g)
        mse = _mean_square_error(abs(G), self.absF, self.mse_denom)

        phs_G = np.angle(G)
        Gprime = self.absF * np.exp(1j*phs_G)
        gprime = fft.ifft2(Gprime)
        # error reduction uses a "minimum update"
        # for G -> G', in this case we use the most
        # common flavor of that; enforce the |F| constraint

        # now apply the ER object domain constraints:
        # positive g, support
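        gprimeprime = gprime
        gprimeprime[self.invmask] = 0  # support constraint
        # positivity constraint: boolean indexing returns a copy, so zero the
        # offending samples through one combined mask on gprimeprime itself.
        # For complex g, "negative" is taken to mean a negative real part.
        gprimeprime[self.mask & (gprimeprime.real < 0)] = 0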

        self.costF.append(mse)
        self.iter += 1
        self.g = gprimeprime
        return gprimeprime

Steepest Descent

As mentioned in the discussion of the Gerchberg-Saxton algorithm, it is not self-evident that GS iterations are the optimum approach (roughly: fastest, and guaranteed to reach a minimum error). Each GS estimate $g''$ is formed by enforcing both domains’ constraints, so the error tends to decrease monotonically. If the iterations are indeed a “contractive mapping,” it is monotonic, but it cannot be guaranteed that they are.

If we think again of image-based wavefront sensing as an optimization problem, one of the most common optimization routines is gradient descent: follow the steepest path downhill. Applied here, this gives the steepest descent algorithm. In his 1982 paper, Fienup showed that this update is formed as:

$$ \begin{align} B &= \sum \left| |F| - |G| \right|^2 \\
g'' - g &= -\frac{1}{4}\partial_g B = \frac{1}{2}[g' - g] \\
g'' &= g + \frac{1}{2}[g' - g] = \frac{1}{2}[g + g'] \end{align} $$

The prefix of $1/2$ is called a “full-length step.” The cost function $B$ is approximately quadratic in $g'-g$: the Fourier transform is linear, and the pupil field is approximately linear in the phase for small phases. The full-length step may therefore under-predict the “best” step, so the factor of one half can be removed, giving the “double-length step” $g'' = g'$. In PRAISE, the doublestep kwarg to the SteepestDescent class’s constructor controls whether this is done. This could be generalized by replacing the factor of $1/2$ or 1 with a gain constant.
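Generalizing the factor of 1/2 (or 1) to a gain gives a one-line object-domain update. In the sketch below, `steepest_descent_update` is a hypothetical helper. `gain=0.5` reproduces the full-length step and `gain=1.0` the double-length step $g'' = g'$. Other values are the assumed generalization. As in PRAISE, the pupil amplitude constraint is reapplied after the step.

```python
import numpy as np


def steepest_descent_update(g, g_prime, pupil_amplitude, gain=0.5):
    """Steepest-descent object-domain update with a general gain (sketch).

    g'' = g + gain * (g' - g), followed by the pupil amplitude constraint.
    gain=0.5 is the full-length step, gain=1.0 the double-length step
    (g'' = g'); other values are an assumed generalization.
    """
    g_step = g + gain * (g_prime - g)
    return pupil_amplitude * np.exp(1j * np.angle(g_step))
```

Inside a loop like PRAISE's `SteepestDescent.step`, this would replace the `doublestep` branch together with the final constraint.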

class SteepestDescent:

    def __init__(self, psf, pupil_amplitude, phase_guess=None, doublestep=True):
        _init_iterative_transform(self, psf, pupil_amplitude, phase_guess)
        self.doublestep = doublestep  # read in step(), so it must be stored

    def step(self):
        G = fft.fft2(self.g)
        mse = _mean_square_error(abs(G), self.absF, self.mse_denom)

        phs_G = np.angle(G)
        Gprime = self.absF * np.exp(1j*phs_G)
        gprime = fft.ifft2(Gprime)

        # steepest descent is the same as GS until the time to form
        # g'' ...
        #  g'' - g = -1/4 partial_g B = 1/2 (g' - g)
        # move g out of the LHS
        # -> g'' = 1/2 (g' - g) + g
        # if doublestep g'' = (g' - g) + g => g'' = g'
        if self.doublestep:
            gprimeprime = gprime
        else:
            gprimeprime = 0.5 * (gprime - self.g) + self.g

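        # finally, apply the object domain constraint to the stepped estimate
        # (using the phase of gprime here would discard the step entirely)
        phs_gprimeprime = np.angle(gprimeprime)
        gprimeprime = self.absg * np.exp(1j*phs_gprimeprime)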

        self.costF.append(mse)
        self.iter += 1
        self.g = gprimeprime
        return gprimeprime

Even though this drives to deeper minima faster than Gerchberg-Saxton, have we achieved optimality? The answer is still no: gradient descent is well known to move slowly toward minima, even if it does eventually find them.

Conjugate Gradient

Without going into why Conjugate-Gradient forms the update this way (consult Nocedal or Bertsekas’s books on optimization), the update is formed as:

$$ \begin{align} g_k'' &= g_k + h_k D_k \\
D_k &= g_k' - g_k + \frac{B_k}{B_{k-1}} D_{k-1} \\
D_0 &= g_0' - g_0 \end{align} $$

The variable $D$ is the direction matrix. It is $m \times n$ for an $m \times n$ array of elements representing $g$. If $g$ is $512 \times 512$, this is a 262,144-variable optimization problem! Without the analytic formulation for $D$, computing it would be rather horrendously expensive, but with the analytic expression it comes essentially for free. The only important quantity that is not pre-specified is $h_k$. $h$ is simply a gain constant, and the subscript $k$ indicates that it has a value for each iteration. You could use a single constant, but in practice this is not the best choice. Instead, the algorithm should be “cooled off” by gradually reducing $h_k$ as it iterates. $h$ should be loosely bounded between 0 and 1. Setting $h$ to a large value, say a million, can make the algorithm diverge. Applying a “correction” a million times as large as the sensed error surely can bring no good. Following this intuition, the change in cost can be monitored, and the gain reduced if the cost is moving “significantly” uphill.
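The “true” CG update in the equations above needs only a little state carried between iterations: the previous cost $B_{k-1}$ and direction $D_{k-1}$. A minimal sketch follows. `cg_update` is a hypothetical helper. The caller supplies $h_k$, so any cooling schedule can be layered on top, and the object domain constraint is also left to the caller. Because $D$ is formed by a few element-wise array operations, its cost is small next to the two FFTs each iteration already requires.

```python
import numpy as np


def cg_update(g, g_prime, B_k, h_k, state=None):
    """One "true" conjugate-gradient step, g'' = g + h_k * D_k (sketch).

    D_0 = g'_0 - g_0 on the first call (state is None), otherwise
    D_k = (g'_k - g_k) + (B_k / B_{k-1}) * D_{k-1}.
    Returns the new estimate (before any object domain constraint) and
    the state (B_k, D_k) to pass into the next call.
    """
    D = g_prime - g
    if state is not None:
        B_km1, D_km1 = state
        D = D + (B_k / B_km1) * D_km1
    return g + h_k * D, (B_k, D)


# Hand-checkable example with constant arrays.
g0 = np.zeros((2, 2), dtype=complex)
gp = np.ones((2, 2), dtype=complex)
g1, state = cg_update(g0, gp, B_k=4.0, h_k=0.5)               # D_0 = 1, g1 = 0.5
g2, state = cg_update(g1, gp, B_k=2.0, h_k=0.5, state=state)  # D_1 = 0.5 + 0.5*1 = 1
```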

Immediately following this derivation, Fienup proposed a different update to $g$, which is the one implemented here. “True” CG, as in the above equations, is implemented but commented out.

class ConjugateGradient:
    def __init__(self, psf, pupil_amplitude, phase_guess=None, hk=1):
        _init_iterative_transform(self, psf, pupil_amplitude, phase_guess)
        self.gprimekm1 = self.g
        self.hk = hk

    def step(self):
        G = fft.fft2(self.g)
        mse = _mean_square_error(abs(G), self.absF, self.mse_denom)
        Bk = mse
        phs_G = np.angle(G)
        Gprime = self.absF * np.exp(1j*phs_G)
        gprime = fft.ifft2(Gprime)

        # this is the update described in Fienup1982 Eq. 36
        # if self.iter == 0:
        #     D = gprime - self.g
        # else:
        #     D = (gprime - self.g) + (Bk/self.Bkm1) * self.Dkm1

        # gprimeprime = self.g + self.hk * D

        gprimeprime = gprime + self.hk * (gprime - self.gprimekm1)

        # finally, apply the object domain constraint
        phs_gprime = np.angle(gprimeprime)
        gprimeprime = self.absg * np.exp(1j*phs_gprime)

        self.costF.append(mse)
        self.iter += 1
        self.Bkm1 = Bk  # bkm1 = "B_{k-1}"; B for iter k-1
        # self.Dkm1 = D
        self.gprimekm1 = gprime
        self.g = gprimeprime
        return gprimeprime

Reducing $h_k$ too quickly will lead to artificially slow convergence. For example, suppose the cost at iteration $k$ were 0.00016 and at $k+1$ it were 0.00017. You may not want to reduce $h$ yet for such a small uphill motion. Introducing lag into the comparison is an interesting approach to deciding “when to cool it off.” Compare cost[k] to cost[k-N] for some moderate value of N, say 5 iterations. Then, if the algorithm is diverging (sustained uphill motion), cooling it off will presumably correct its course. This should be combined with an eventual cool-off anyway, perhaps based on the value of cost[k] itself, which slows the algorithm once it has good data agreement.
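The lagged comparison can be sketched in a few lines. `cooled_gain` is a hypothetical helper. The defaults for the lag N, the reduction factor, and the floor on $h$ are illustrative rather than tuned values.

```python
def cooled_gain(costs, h, lag=5, factor=0.5, h_min=1e-3):
    """Return the gain to use for the next iteration.

    Compares cost[k] with cost[k - lag]. If the cost is higher now than
    `lag` iterations ago (sustained uphill motion), h is multiplied by
    `factor`, but never reduced below h_min. Otherwise h is unchanged, so
    a small one-step blip does not trigger a cool-off.
    """
    if len(costs) > lag and costs[-1] > costs[-1 - lag]:
        return max(h * factor, h_min)
    return h


history = [0.5, 0.4, 0.3, 0.2, 0.1, 0.11]  # one small uphill step at the end
h = cooled_gain(history, h=1.0)             # cost[k] < cost[k-5], so h stays 1.0
```

Inside a loop this would be called once per iteration. An example is `cg.hk = cooled_gain(cg.costF, cg.hk)` for a ConjugateGradient instance `cg`. The eventual cost-based cool-off described above could be added as a second condition.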

Input-Input, Output-Output, Hybrid Input-Output

These algorithms are relevant to image reconstruction but, in my opinion, work poorly for image-based wavefront sensing.

Misell’s algorithms

Misell’s algorithms are discussed in a later post.

Wrap-up

In this post, we walked through the implementation of several iterative transform type Phase Retrieval algorithms, in the context of image-based wavefront sensing. In the next post, we will discuss in more detail the practical nature of these algorithms.