# -*- coding: utf-8 -*-
"""
Calculate the slit curvature (tilt and shear) from a reference spectrum with high SNR, e.g. a wavelength calibration image
Authors
-------
Nikolai Piskunov
Ansgar Wehrhahn
Version
--------
0.9 - NP - IDL Version
1.0 - AW - Python Version
License
-------
....
"""
import logging
import matplotlib.pyplot as plt
import numpy as np
from numpy.polynomial.polynomial import polyval2d
from scipy import signal
from scipy.ndimage import gaussian_filter1d, median_filter
from scipy.optimize import least_squares
from tqdm import tqdm
from .extract import fix_parameters
from .util import make_index
from .util import polyfit2d_2 as polyfit2d
logger = logging.getLogger(__name__)
class ProgressPlot:  # pragma: no cover
    """
    Interactive figure that shows the progress of the curvature fit.

    The first panel shows a 1D vector (clipped at 1, log scale) with the
    current line marked; the second and third panels show an image cutout
    and the corresponding model, each with the fitted curvature overplotted.
    """

    def __init__(self, ncol, width, title=None):
        plt.ion()
        fig, (ax1, ax2, ax3) = plt.subplots(ncols=3)
        plot_title = "Curvature in each order"
        if title is not None:
            plot_title = f"{title}\n{plot_title}"
        fig.suptitle(plot_title)

        (line1,) = ax1.plot(np.arange(ncol) + 1)
        # Marker for the line that is currently being fitted
        (line2,) = ax1.plot(0, 0, "d")
        ax1.set_yscale("log")

        self.ncol = ncol
        self.width = width * 2 + 1
        self.fig = fig
        self.ax1 = ax1
        self.ax2 = ax2
        self.ax3 = ax3
        self.line1 = line1
        self.line2 = line2

    def update_plot1(self, vector, peak, offset=0):
        """
        Show ``vector`` (placed at column ``offset``) and mark column ``peak``.

        Values are clipped at 1 because the axis is logarithmic.
        """
        data = np.ones(self.ncol)
        data[offset : len(vector) + offset] = np.clip(vector, 1, None)
        self.line1.set_ydata(data)
        # Line2D data must be sequences; scalars are deprecated since
        # Matplotlib 3.7
        self.line2.set_xdata([peak])
        self.line2.set_ydata([data[peak]])
        self.ax1.set_ylim((data.min(), data.max()))
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()

    def update_plot2(self, img, model, tilt, shear, peak):
        """
        Show the image cutout and the model with the curvature curve
        ``x = peak + (tilt + shear * y) * y`` drawn on top of both.
        """
        self.ax2.clear()
        self.ax3.clear()
        self.ax2.imshow(img)
        self.ax3.imshow(model)

        nrows, _ = img.shape
        middle = nrows // 2
        # Row offsets relative to the central row of the cutout
        y = np.arange(-middle, -middle + nrows)
        x = peak + (tilt + shear * y) * y
        y += middle
        self.ax2.plot(x, y, "r")
        self.ax3.plot(x, y, "r")
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()

    def close(self):
        """Close this progress figure and leave interactive mode."""
        # Close our own figure explicitly instead of whichever one is current
        plt.close(self.fig)
        plt.ioff()
class Curvature:
    """
    Determine the slit curvature (tilt and shear) along each order.

    Each strong line of a reference spectrum is fitted with a peak profile
    whose centre shifts with the row offset ``y`` from the order trace as
    ``(tilt + shear * y) * y``. The per-line results are then smoothed with a
    polynomial in column for every order (``mode="1D"``) or with one 2D
    polynomial in column and order index (``mode="2D"``).

    Parameters
    ----------
    orders : array of shape (nord, ndeg)
        Polynomial coefficients of the order traces, in ``np.polyval`` order.
    extraction_width : float or array, optional
        Rows below/above the trace to use; expanded by ``fix_parameters`` into
        one (below, above) pair per order.
    column_range : array of shape (nord, 2), optional
        Column interval ``[start, end)`` used for each order.
    order_range : 2 tuple, optional
        ``(first, last + 1)`` of the orders to process; defaults to all.
    window_width : int, optional
        Half width in columns of the cutout around each line; also the minimum
        peak separation and the margin of ignored columns at the order edges.
    peak_threshold : float, optional
        Minimum peak prominence, in units of the 68th percentile of the
        median-subtracted spectrum.
    peak_width : float, optional
        Minimum peak width in columns, passed to ``scipy.signal.find_peaks``.
    fit_degree : int or 2 tuple, optional
        Degree of the smoothing polynomial; a 2 tuple in "2D" mode.
    sigma_cutoff : float, optional
        Stored on the instance but not used by this class.
    mode : {"1D", "2D"}, optional
        Smoothing mode, see above.
    plot : int, optional
        If truthy, plot the results; ``plot >= 2`` also shows every line fit.
    plot_title : str, optional
        Prefix for the figure titles.
    peak_function : {"gaussian", "lorentzian"}, optional
        Line profile used in the per-line fit.
    curv_degree : {1, 2}, optional
        1 fits only the tilt, 2 fits tilt and shear.
    """

    def __init__(
        self,
        orders,
        extraction_width=0.5,
        column_range=None,
        order_range=None,
        window_width=9,
        peak_threshold=10,
        peak_width=1,
        fit_degree=2,
        sigma_cutoff=3,
        mode="1D",
        plot=False,
        plot_title=None,
        peak_function="gaussian",
        curv_degree=2,
    ):
        self.orders = orders
        self.extraction_width = extraction_width
        self.column_range = column_range
        if order_range is None:
            order_range = (0, self.nord)
        self.order_range = order_range
        self.window_width = window_width
        self.threshold = peak_threshold
        self.peak_width = peak_width
        self.fit_degree = fit_degree
        self.sigma_cutoff = sigma_cutoff
        # Validated by the ``mode`` property setter
        self.mode = mode
        self.plot = plot
        self.plot_title = plot_title
        self.curv_degree = curv_degree
        self.peak_function = peak_function

        if self.mode == "1D":
            # fit degree is an integer
            if not np.isscalar(self.fit_degree):
                self.fit_degree = self.fit_degree[0]
        elif self.mode == "2D":
            # fit degree is a 2 tuple
            if np.isscalar(self.fit_degree):
                self.fit_degree = (self.fit_degree, self.fit_degree)

    @property
    def nord(self):
        """Number of orders in ``self.orders``."""
        return self.orders.shape[0]

    @property
    def n(self):
        """Number of orders selected by ``self.order_range``."""
        return self.order_range[1] - self.order_range[0]

    @property
    def mode(self):
        """Smoothing mode, either "1D" or "2D"."""
        return self._mode

    @mode.setter
    def mode(self, value):
        if value not in ["1D", "2D"]:
            raise ValueError(
                f"Value for 'mode' not understood. Expected one of ['1D', '2D'] but got {value}"
            )
        self._mode = value

    def _fix_inputs(self, original):
        """
        Expand the extraction parameters to per-order values (via
        ``fix_parameters``) and restrict everything to ``order_range``.

        Afterwards ``order_range`` is reset to ``(0, n)``, since the stored
        arrays then only contain the selected orders.
        """
        orders = self.orders
        extraction_width = self.extraction_width
        column_range = self.column_range

        nrow, ncol = original.shape
        nord = len(orders)

        extraction_width, column_range, orders = fix_parameters(
            extraction_width, column_range, orders, nrow, ncol, nord
        )

        self.column_range = column_range[self.order_range[0] : self.order_range[1]]
        self.extraction_width = extraction_width[
            self.order_range[0] : self.order_range[1]
        ]
        self.orders = orders[self.order_range[0] : self.order_range[1]]
        self.order_range = (0, self.n)

    def _find_peaks(self, vec, cr):
        """
        Find the strong lines in the spectrum of one order.

        Parameters
        ----------
        vec : array of shape (cr[1] - cr[0],)
            spectrum of the order within its column range; not modified
        cr : 2 tuple
            column range of the order

        Returns
        -------
        vec : ndarray
            median-subtracted copy of the input, masked values set to 0
        peaks : ndarray of int
            line positions in absolute column coordinates
        """
        # This should probably be the same as in the wavelength calibration
        # Subtract out-of-place: ``vec`` is usually a view into the caller's
        # extracted spectrum, which must not be modified.
        vec = vec - np.ma.median(vec)
        vec = np.ma.filled(vec, 0)
        height = np.percentile(vec, 68) * self.threshold
        peaks, _ = signal.find_peaks(
            vec, prominence=height, width=self.peak_width, distance=self.window_width
        )

        # Remove peaks at the edge, where the fit window would not fit
        peaks = peaks[
            (peaks >= self.window_width + 1)
            & (peaks < len(vec) - self.window_width - 1)
        ]
        # Remove the offset, due to vec being a subset of extracted
        peaks += cr[0]
        return vec, peaks

    def _determine_curvature_single_line(self, original, peak, ycen, ycen_int, xwd):
        """
        Fit the curvature of a single line.

        A peak profile (``self.peak_function``) is fitted to an image cutout
        around the line. In every row the profile is scaled by that row's mean
        flux and its centre is shifted by the curvature
        ``(tilt + shear * y) * y``, where ``y`` is the row offset from the
        order centre.

        Parameters
        ----------
        original : array of shape (nrows, ncols)
            whole input image
        peak : int
            column position of the peak
        ycen : array of shape (ncols,)
            sub-pixel part of the order centre (centre minus ``ycen_int``)
        ycen_int : array of shape (ncols,)
            integer part of the order centre
        xwd : 2 tuple
            number of rows below and above the order centre to use

        Returns
        -------
        tilt : float
            first order curvature
        shear : float
            second order curvature (0 if ``curv_degree == 1``)

        Raises
        ------
        RuntimeError
            if the cutout contains no usable (finite, non-constant) data
        ValueError
            if ``curv_degree`` is neither 1 nor 2
        """
        peak_func = {"gaussian": gaussian, "lorentzian": lorentzian}
        peak_func = peak_func[self.peak_function]
        if self.curv_degree not in (1, 2):
            raise ValueError("Only curvature degrees 1 and 2 are supported")

        _, ncol = original.shape

        # Columns within +- window_width of the line, clipped to the image
        x = peak + np.arange(-self.window_width, self.window_width + 1)
        x = x[(x >= 0) & (x < ncol)]
        xmin, xmax = x[0], x[-1] + 1
        x = x[None, :]

        # Row offset of every pixel of the cutout from the order centre
        y = np.arange(-xwd[0], xwd[1] + 1)[:, None] - ycen[xmin:xmax][None, :]

        idx = make_index(ycen_int - xwd[0], ycen_int + xwd[1], xmin, xmax)
        img = original[idx]

        # Map the 1st/99th percentile of the cutout to 0/1. Both percentiles
        # come from the raw cutout, and the result is a new float array, so
        # this works the same for masked/unmasked and integer/float images.
        valid = np.ma.compressed(img)
        if valid.size == 0:
            raise RuntimeError("No valid pixels around the line")
        lower, upper = np.percentile(valid, (1, 99))
        if not (np.isfinite(lower) and np.isfinite(upper) and upper > lower):
            raise RuntimeError("Cannot normalize the image around the line")
        img = np.ma.clip((img - lower) / (upper - lower), 0, 1)

        # Row-wise scale of the model: mean normalized flux of each row
        sl = np.ma.mean(img, axis=1)[:, None]

        if self.curv_degree == 1:

            def shift(curv):
                return curv[0] * y

        else:

            def shift(curv):
                return (curv[0] + curv[1] * y) * y

        def model(coef):
            """Residuals (model - data) for all pixels, masked where img is."""
            A, middle, sig, *curv = coef
            mod = peak_func(x, A, middle + shift(curv), sig)
            mod *= sl
            return (mod - img).ravel()

        def model_compressed(coef):
            """Residuals of the unmasked pixels only."""
            return np.ma.compressed(model(coef))

        # Initial guess: amplitude from the bright end of the normalized
        # cutout, centre at the detected peak, width a quarter of the window
        A = np.nanpercentile(np.ma.compressed(img), 95)
        sig = (xmax - xmin) / 4  # TODO: better width estimate
        x0 = [A, peak, sig] + [0] * self.curv_degree
        res = least_squares(
            model_compressed, x0=x0, method="trf", loss="soft_l1", f_scale=0.1
        )

        tilt = res.x[3]
        shear = res.x[4] if self.curv_degree == 2 else 0

        if self.plot >= 2:  # pragma: no cover
            # Use the full (masked) residuals: res.fun only holds the
            # unmasked pixels and cannot be reshaped to the cutout.
            fitted = model(res.x).reshape(img.shape) + img
            self.progress.update_plot2(img, fitted, tilt, shear, res.x[1] - xmin)

        return tilt, shear

    def _fit_curvature_single_order(self, peaks, tilt, shear):
        """
        Fit polynomials in column position to the tilt and shear of one order.

        Lines whose tilt lies more than 5 "sigma" (estimated from the 32nd and
        68th percentile around the median) from the median are rejected
        first. Both polynomials are fitted with a robust ("arctan") loss.

        Parameters
        ----------
        peaks : array of shape (nlines,)
            column positions of the lines
        tilt, shear : array of shape (nlines,)
            curvature measured for each line

        Returns
        -------
        coef_tilt, coef_shear : array of shape (fit_degree + 1,)
            polynomial coefficients in ``np.polyval`` order; all zeros if
            there are no lines or the fit fails
        peaks : array
            the line positions that survived the outlier rejection
        """
        ncoef = self.fit_degree + 1
        if len(peaks) == 0:
            logger.warning("No lines found in this order. Using no curvature instead")
            return np.zeros(ncoef), np.zeros(ncoef), peaks

        try:
            middle = np.median(tilt)
            sigma = np.percentile(tilt, (32, 68))
            sigma = middle - sigma[0], sigma[1] - middle
            mask = (tilt >= middle - 5 * sigma[0]) & (tilt <= middle + 5 * sigma[1])
            peaks, tilt, shear = peaks[mask], tilt[mask], shear[mask]

            res = least_squares(
                lambda coef: np.polyval(coef, peaks) - tilt,
                x0=np.zeros(ncoef),
                loss="arctan",
            )
            coef_tilt = res.x

            res = least_squares(
                lambda coef: np.polyval(coef, peaks) - shear,
                x0=np.zeros(ncoef),
                loss="arctan",
            )
            coef_shear = res.x
        except Exception:
            # Best effort: a failed order falls back to "no curvature"
            # instead of aborting the whole run.
            logger.error(
                "Could not fit the curvature of this order. Using no curvature instead"
            )
            coef_tilt = np.zeros(ncoef)
            coef_shear = np.zeros(ncoef)

        return coef_tilt, coef_shear, peaks

    def _determine_curvature_all_lines(self, original, extracted):
        """
        Measure the curvature of every detected line in every order.

        Parameters
        ----------
        original : array of shape (nrows, ncols)
            whole input image
        extracted : 2D array
            extracted spectra, indexed as [order, column]; not modified

        Returns
        -------
        all_peaks, all_tilt, all_shear : list of arrays
            per order, the column positions of the lines whose fit succeeded
            and their tilt and shear
        plot_vec : list of arrays
            per order, the median-subtracted spectrum used for peak finding
        """
        ncol = original.shape[1]
        # Store data from all orders
        all_peaks = []
        all_tilt = []
        all_shear = []
        plot_vec = []

        for j in tqdm(range(self.n), desc="Order"):
            logger.debug("Calculating tilt of order %i out of %i", j + 1, self.n)

            cr = self.column_range[j]
            xwd = self.extraction_width[j]
            # Split the order centre into integer rows and a sub-pixel part
            ycen = np.polyval(self.orders[j], np.arange(ncol))
            ycen_int = ycen.astype(int)
            ycen -= ycen_int

            # Find peaks
            vec = extracted[j, cr[0] : cr[1]]
            vec, peaks = self._find_peaks(vec, cr)
            npeaks = len(peaks)

            # Determine curvature for each line separately
            tilt = np.zeros(npeaks)
            shear = np.zeros(npeaks)
            mask = np.full(npeaks, True)
            for ipeak, peak in tqdm(
                enumerate(peaks), total=npeaks, desc="Peak", leave=False
            ):
                if self.plot >= 2:  # pragma: no cover
                    self.progress.update_plot1(vec, peak, cr[0])
                try:
                    tilt[ipeak], shear[ipeak] = self._determine_curvature_single_line(
                        original, peak, ycen, ycen_int, xwd
                    )
                except RuntimeError:  # pragma: no cover
                    # Skip lines that could not be fitted
                    mask[ipeak] = False

            # Store results
            all_peaks += [peaks[mask]]
            all_tilt += [tilt[mask]]
            all_shear += [shear[mask]]
            plot_vec += [vec]
        return all_peaks, all_tilt, all_shear, plot_vec

    def fit(self, peaks, tilt, shear):
        """
        Smooth the per-line curvature measurements.

        Parameters
        ----------
        peaks, tilt, shear : list of arrays
            per order, the line positions and their measured tilt and shear

        Returns
        -------
        coef_tilt, coef_shear : array
            in "1D" mode, arrays of shape (n, fit_degree + 1) with one
            ``np.polyval`` polynomial per order; in "2D" mode, the
            coefficients returned by ``polyfit2d`` for a fit over
            (column, order index)
        """
        if self.mode == "1D":
            coef_tilt = np.zeros((self.n, self.fit_degree + 1))
            coef_shear = np.zeros((self.n, self.fit_degree + 1))
            for j in range(self.n):
                coef_tilt[j], coef_shear[j], _ = self._fit_curvature_single_order(
                    peaks[j], tilt[j], shear[j]
                )
        elif self.mode == "2D":
            x = np.concatenate(peaks)
            y = np.concatenate([np.full(len(p), i) for i, p in enumerate(peaks)])
            z = np.concatenate(tilt)
            coef_tilt = polyfit2d(x, y, z, degree=self.fit_degree, loss="arctan")
            z = np.concatenate(shear)
            coef_shear = polyfit2d(x, y, z, degree=self.fit_degree, loss="arctan")

        return coef_tilt, coef_shear

    def eval(self, peaks, order, coef_tilt, coef_shear):
        """
        Evaluate the fitted curvature at the given columns and orders.

        Parameters
        ----------
        peaks : array
            column positions
        order : array of int, same shape as ``peaks``
            order index of each position
        coef_tilt, coef_shear : array
            coefficients as returned by ``fit``

        Returns
        -------
        tilt, shear : array, same shape as ``peaks``
        """
        if self.mode == "1D":
            tilt = np.zeros(peaks.shape)
            shear = np.zeros(peaks.shape)
            for i in np.unique(order):
                idx = order == i
                tilt[idx] = np.polyval(coef_tilt[i], peaks[idx])
                shear[idx] = np.polyval(coef_shear[i], peaks[idx])
        elif self.mode == "2D":
            tilt = polyval2d(peaks, order, coef_tilt)
            shear = polyval2d(peaks, order, coef_shear)
        return tilt, shear

    def plot_results(
        self, ncol, plot_peaks, plot_vec, plot_tilt, plot_shear, tilt_x, shear_x
    ):  # pragma: no cover
        """
        Plot the used lines and the measured and fitted curvature.

        Three figures with one panel per order are created: the spectrum with
        the used lines marked, the tilt, and the shear (measurements as
        crosses, fitted polynomial as a line).

        Parameters
        ----------
        ncol : int
            number of image columns (x-axis limit)
        plot_peaks, plot_vec, plot_tilt, plot_shear : list of arrays
            per order, as returned by ``_determine_curvature_all_lines``
        tilt_x, shear_x : array
            fitted coefficients as returned by ``fit``
        """
        nrows = self.n // 2 + self.n % 2

        def new_figure(name):
            """Create a titled figure with a 2-column grid of panels."""
            fig, axs = plt.subplots(nrows=nrows, ncols=2, squeeze=False)
            title = name if self.plot_title is None else f"{self.plot_title}\n{name}"
            fig.suptitle(title)
            return axs

        axes = new_figure("Peaks")
        axes1 = new_figure("1st Order Curvature")
        axes2 = new_figure("2nd Order Curvature")
        plt.subplots_adjust(hspace=0)

        # Evaluate the fitted curvature over the column range of every order
        fit_tilt, fit_shear = [], []
        for j in range(self.n):
            cr = self.column_range[j]
            x = np.arange(cr[0], cr[1])
            t_j, s_j = self.eval(x, np.full(len(x), j), tilt_x, shear_x)
            fit_tilt.append(t_j)
            fit_shear.append(s_j)

        # Common y-limits for all panels, padded around the fitted curves
        t_lower = min(v.min() * (0.5 if v.min() > 0 else 1.5) for v in fit_tilt)
        t_upper = max(v.max() * (1.5 if v.max() > 0 else 0.5) for v in fit_tilt)
        s_lower = min(v.min() * (0.5 if v.min() > 0 else 1.5) for v in fit_shear)
        s_upper = max(v.max() * (1.5 if v.max() > 0 else 0.5) for v in fit_shear)

        for j in range(self.n):
            row, col = j // 2, j % 2
            cr = self.column_range[j]
            x = np.arange(cr[0], cr[1])
            peaks = plot_peaks[j]
            vec = np.clip(plot_vec[j], 0, None)
            bottom_row = j in (self.n - 1, self.n - 2)

            # Figure Peaks found (and used)
            ax = axes[row, col]
            ax.plot(x, vec)
            ax.plot(peaks, vec[peaks - cr[0]], "X")
            ax.set_xlim([0, ncol])
            if not bottom_row:
                ax.get_xaxis().set_ticks([])

            # Figures 1st and 2nd order curvature
            panels = (
                (axes1[row, col], plot_tilt[j], fit_tilt[j], t_lower, t_upper,
                 "tilt [pixel/pixel]"),
                (axes2[row, col], plot_shear[j], fit_shear[j], s_lower, s_upper,
                 "shear [pixel/pixel**2]"),
            )
            for axis, measured, fitted, lower, upper, ylabel in panels:
                axis.plot(peaks, measured, "rX")
                axis.plot(x, fitted)
                axis.set_xlim(0, ncol)
                axis.set_ylim(lower, upper)
                if not bottom_row:
                    axis.get_xaxis().set_ticks([])
                else:
                    axis.set_xlabel("x [pixel]")
                if j == self.n // 2 + 1:
                    axis.set_ylabel(ylabel)

        # Remove the unused panel of every figure when n is odd
        for axs in (axes, axes1, axes2):
            for unused in axs.flat[self.n :]:
                unused.remove()

        plt.show()

    def plot_comparison(self, original, tilt, shear, peaks):  # pragma: no cover
        """
        Show the image strips of all orders stacked on top of each other,
        with the fitted curvature drawn through every used line.

        Parameters
        ----------
        original : array of shape (nrows, ncols)
            whole input image
        tilt, shear : array of shape (nord, ncols)
            curvature evaluated at every column of every order
        peaks : list of arrays
            column positions of the used lines in each order
        """
        _, ncol = original.shape
        output = np.zeros((np.sum(self.extraction_width) + self.nord, ncol))
        # Start row of each order's strip in the stacked output image
        pos = [0]
        x = np.arange(ncol)
        for i in range(self.nord):
            ycen = np.polyval(self.orders[i], x)
            yb = ycen - self.extraction_width[i, 0]
            yt = ycen + self.extraction_width[i, 1]
            xl, xr = self.column_range[i]
            index = make_index(yb, yt, xl, xr)
            yl = pos[i]
            yr = pos[i] + index[0].shape[0]
            output[yl:yr, xl:xr] = original[index]
            pos += [yr]

        vmin, vmax = np.percentile(output[output != 0], (5, 95))
        plt.imshow(output, vmin=vmin, vmax=vmax, origin="lower", aspect="auto")

        for i in range(self.nord):
            ew = self.extraction_width[i]
            y = np.arange(-ew[0], ew[1] + 1)
            for p in peaks[i]:
                xc = p + y * tilt[i, p] + y ** 2 * shear[i, p]
                plt.plot(xc, y + pos[i] + ew[0], "r")

        # Label each strip with its order number, at the middle of the strip
        locs = np.sum(self.extraction_width, axis=1) + 1
        locs = np.array([0, *np.cumsum(locs)[:-1]])
        locs[:-1] += (np.diff(locs) * 0.5).astype(int)
        locs[-1] += ((output.shape[0] - locs[-1]) * 0.5).astype(int)
        plt.yticks(locs, range(len(locs)))

        if self.plot_title is not None:
            plt.title(self.plot_title)
        plt.xlabel("x [pixel]")
        plt.ylabel("order")
        plt.show()

    def execute(self, extracted, original):
        """
        Determine the slit curvature for all selected orders.

        Parameters
        ----------
        extracted : 2D array
            extracted spectra, indexed as [order, column]
        original : array of shape (nrows, ncols)
            whole input image

        Returns
        -------
        tilt, shear : array with the shape of ``extracted``
            fitted curvature evaluated at every column of every order
        """
        logger.info("Determining the Slit Curvature")

        _, ncol = original.shape
        self._fix_inputs(original)

        if self.plot >= 2:  # pragma: no cover
            self.progress = ProgressPlot(ncol, self.window_width, title=self.plot_title)

        try:
            peaks, tilt, shear, vec = self._determine_curvature_all_lines(
                original, extracted
            )
            coef_tilt, coef_shear = self.fit(peaks, tilt, shear)
        finally:
            # Always close the interactive progress figure, even on failure
            if self.plot >= 2:  # pragma: no cover
                self.progress.close()

        if self.plot:  # pragma: no cover
            self.plot_results(ncol, peaks, vec, tilt, shear, coef_tilt, coef_shear)

        iorder, ipeaks = np.indices(extracted.shape)
        tilt, shear = self.eval(ipeaks, iorder, coef_tilt, coef_shear)

        if self.plot:  # pragma: no cover
            self.plot_comparison(original, tilt, shear, peaks)

        return tilt, shear
# Line profile models; ``Curvature`` selects one via its ``peak_function`` option.
def gaussian(x, A, mu, sig):
    """
    Evaluate a Gaussian line profile.

    A: peak height (the value at ``x == mu``)
    mu: centre of the line
    sig: standard deviation (width) of the profile
    """
    exponent = (x - mu) ** 2 / (2 * sig ** 2)
    return A * np.exp(-exponent)
def lorentzian(x, A, x0, mu):
    """
    Evaluate a Lorentzian (Cauchy) line profile.

    A: amplitude; note that the value at ``x == x0`` is ``4 * A / mu``,
       not ``A``
    x0: centre of the line
    mu: full width at half maximum of the profile
    """
    half_width = 0.5 * mu
    denominator = (x - x0) ** 2 + half_width ** 2
    return A * mu / denominator