"""
Wavelength Calibration
by comparison to a reference spectrum
Loosely based on the IDL wavecal function
"""
import logging
from os.path import dirname, exists, join
import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from numpy.polynomial.polynomial import Polynomial, polyval2d
from scipy import signal
from scipy.constants import speed_of_light
from scipy.interpolate import interp1d
from scipy.ndimage.filters import gaussian_filter1d
from scipy.ndimage.morphology import grey_closing
from scipy.optimize import curve_fit
from tqdm import tqdm
from . import util
logger = logging.getLogger(__name__)
[docs]
def polyfit(x, y, deg):
    """Least-squares polynomial fit, coefficients ordered highest power first.

    Uses ``Polynomial.fit`` with an empty ``domain`` and reverses the
    resulting coefficient array before returning it.

    Parameters
    ----------
    x, y : array-like
        sample points and values
    deg : int
        degree of the fitted polynomial

    Returns
    -------
    coef : array of shape (deg + 1,)
        polynomial coefficients in reversed (descending power) order
    """
    fitted = Polynomial.fit(x, y, deg, domain=[])
    return fitted.coef[::-1]
[docs]
class AlignmentPlot:
    """
    Makes a plot which can be clicked to align the two spectra, reference and observed

    The observation is drawn in the red channel on even image rows and the
    reference lines in the green channel on odd image rows. Two consecutive
    clicks on different spectra shift the reference by the distance between
    the click positions.
    """

    def __init__(self, ax, obs, lines, offset=(0, 0), plot_title=None):
        """
        Parameters
        ----------
        ax : matplotlib axes
            axes to draw the alignment image into
        obs : array of shape (ntrace, ncol)
            observed spectrum
        lines : recarray
            reference lines, read via the fields "order", "xfirst", "xlast",
            "height" and "width"
        offset : tuple(int, int), optional
            initial (order, column) offset of the reference (default: (0, 0))
        plot_title : str, optional
            extra title line placed above the default title
        """
        self.im = ax
        self.first = True
        self.ntrace, self.ncol = obs.shape
        # colour channel indices of the RGB image
        self.RED, self.GREEN, self.BLUE = 0, 1, 2

        self.obs = obs
        self.lines = lines
        self.plot_title = plot_title

        # state of the first click of each click pair
        self.order_first = 0
        self.spec_first = ""
        self.x_first = 0
        self.offset = list(offset)

        self.make_ref_image()

    def make_ref_image(self):
        """create and show the reference plot, with the two spectra"""
        ref_image = np.zeros((self.ntrace * 2, self.ncol, 3))
        for idx in range(self.ntrace):
            # observation goes to the even rows
            ref_image[idx * 2, :, self.RED] = 10 * np.ma.filled(self.obs[idx], 0)
            if 0 <= idx + self.offset[0] < self.ntrace:
                for line in self.lines[self.lines["order"] == idx]:
                    first = int(np.clip(line["xfirst"] + self.offset[1], 0, self.ncol))
                    last = int(np.clip(line["xlast"] + self.offset[1], 0, self.ncol))
                    # reference goes to the odd row of the shifted order
                    order = (idx + self.offset[0]) * 2 + 1
                    ref_image[order, first:last, self.GREEN] = (
                        10
                        * line["height"]
                        * signal.windows.gaussian(last - first, line["width"])
                    )
        ref_image = np.clip(ref_image, 0, 1)
        ref_image[ref_image < 0.1] = 0

        self.im.imshow(
            ref_image,
            aspect="auto",
            origin="lower",
            extent=(-0.5, self.ncol - 0.5, -0.5, self.ntrace - 0.5),
        )
        title = "Alignment, Observed: RED, Reference: GREEN\nGreen should be above red!"
        if self.plot_title is not None:
            title = f"{self.plot_title}\n{title}"
        self.im.figure.suptitle(title)
        self.im.axes.set_xlabel("x [pixel]")
        self.im.axes.set_ylabel("Order")

        self.im.figure.canvas.draw()

    def connect(self):
        """connect the click event with the appropriate function"""
        self.cidclick = self.im.figure.canvas.mpl_connect(
            "button_press_event", self.on_click
        )

    def on_click(self, event):
        """On click offset the reference by the distance between click positions"""
        if event.ydata is None:
            return
        order = int(np.floor(event.ydata))
        # upper half of an order row is the reference, lower half the observation
        spec = "ref" if (event.ydata - order) > 0.5 else "obs"
        x = event.xdata
        # Report the spectrum that was actually clicked; `spec` is a non-empty
        # string, so testing its truthiness would always report "ref".
        print("Order: %i, Spectrum: %s, x: %g" % (order, spec, x))

        # on every second click
        if self.first:
            self.first = False
            self.order_first = order
            self.spec_first = spec
            self.x_first = x
        else:
            # Clicked different spectra
            if spec != self.spec_first:
                self.first = True
                direction = -1 if spec == "ref" else 1
                offset_orders = int(order - self.order_first) * direction
                offset_x = int(x - self.x_first) * direction
                self.offset[0] -= offset_orders - 1
                self.offset[1] -= offset_x
                self.make_ref_image()
[docs]
class LineAtlas:
    """Reference atlas for one element: a spectrum and a matching line list.

    Looks for ``<element>.fits`` (reference spectrum) and
    ``<element>_list.txt`` (line list) in the search directories. Whichever
    is missing is derived from the other one.

    Attributes
    ----------
    wave, flux : array
        reference spectrum
    linelist : recarray
        records with the fields "wave", "heights" and "element"
    """

    def __init__(self, element, medium="vac", search_dirs=None):
        """
        Parameters
        ----------
        element : str
            atlas name, lower-cased to build the file names
        medium : {"vac", "air"}, optional
            if "air", wavelengths are converted with ``util.vac2air``
        search_dirs : iterable, optional
            directories searched before the packaged default directory

        Raises
        ------
        FileNotFoundError
            If neither the FITS spectrum nor the line list is found
        """
        self.element = element
        self.medium = medium

        default_dir = join(dirname(__file__), "instruments", "defaults", "atlas")
        if search_dirs is None:
            dirs = [default_dir]
        else:
            dirs = list(search_dirs) + [default_dir]

        base_fits = element.lower() + ".fits"
        base_list = element.lower() + "_list.txt"
        fname_fits = fname_list = None
        # The first directory containing a file wins, independently per file
        for d in dirs:
            if fname_fits is None and exists(join(d, base_fits)):
                fname_fits = join(d, base_fits)
            if fname_list is None and exists(join(d, base_list)):
                fname_list = join(d, base_list)

        has_fits = fname_fits is not None
        has_list = fname_list is not None

        if not has_fits and not has_list:
            # str() so that path-like entries in search_dirs can be joined
            searched = ", ".join(str(d) for d in dirs)
            raise FileNotFoundError(
                f"No atlas files found for '{element}' "
                f"(looked for {base_fits} and {base_list} in: {searched})"
            )

        if has_list:
            raw = np.genfromtxt(fname_list, dtype=str, ndmin=2)
            wpos = raw[:, 0].astype(np.float64)
            if raw.shape[1] > 1:
                elem_ids = raw[:, 1]
            else:
                elem_ids = np.full(len(wpos), element)

        if has_fits:
            self.wave, self.flux = self.load_fits(fname_fits)
            if not has_list:
                logger.warning(
                    "No dedicated linelist found for %s, determining peaks from spectrum.",
                    element,
                )
                module = WavelengthCalibration(plot=False)
                _, peaks = module._find_peaks(self.flux)
                wpos = np.interp(peaks, np.arange(len(self.wave)), self.wave)
                elem_ids = np.full(len(wpos), element)
            heights = self._lookup_heights(self.wave, self.flux, wpos)
        else:
            # Only line list, no FITS — synthesize a reference spectrum
            logger.info(
                "No reference spectrum for %s, synthesizing from line list.", element
            )
            self.wave, self.flux = self._synthesize_spectrum(wpos)
            heights = np.ones(len(wpos))

        self.linelist = np.rec.fromarrays(
            [wpos, heights, elem_ids], names=["wave", "heights", "element"]
        )

        if medium == "air":
            self.wave = util.vac2air(self.wave)
            self.linelist["wave"] = util.vac2air(self.linelist["wave"])

    @staticmethod
    def _lookup_heights(wave, flux, wpos):
        """Return the flux at the atlas sample found by searchsorted for each line.

        ``searchsorted`` returns ``len(wave)`` for positions beyond the last
        sample, which is not a valid index; such positions are clipped to the
        last sample instead of raising an IndexError.
        """
        indices = np.searchsorted(wave, wpos)
        indices = np.clip(indices, 0, len(wave) - 1)
        return flux[indices]

    @staticmethod
    def _normalize_flux(spec):
        """Replace NaNs by 0, scale to a maximum of 1 and clip negative values.

        Non-floating input (e.g. integer FITS data) is converted to float64
        first, because the in-place division is not defined for integer arrays.
        """
        spec = np.asarray(spec)
        if not np.issubdtype(spec.dtype, np.floating):
            spec = spec.astype(np.float64)
        spec = np.nan_to_num(spec, nan=0.0)
        smax = np.max(spec)
        if smax > 0:
            spec /= smax
        return np.clip(spec, 0, None)

    @staticmethod
    def _synthesize_spectrum(wpos, n=10_000, width=5):
        """Build a synthetic reference spectrum from a list of line wavelengths.

        Parameters
        ----------
        wpos : array
            line wavelengths, must not be empty
        n : int, optional
            number of points of the synthetic spectrum
        width : float, optional
            standard deviation of each Gaussian line, in samples

        Returns
        -------
        wave, flux : array of shape (n,)
            evenly spaced wavelength grid (1% margin on both sides) and the
            flux scaled to a maximum of 1
        """
        wmin, wmax = wpos.min(), wpos.max()
        margin = (wmax - wmin) * 0.01
        wmin -= margin
        wmax += margin
        wave = np.linspace(wmin, wmax, num=n, endpoint=True)
        flux = np.zeros(n)
        half = int(width * 5)
        for mid in np.searchsorted(wave, wpos):
            lo = max(mid - half, 0)
            hi = min(mid + half, n)
            if hi > lo:
                flux[lo:hi] += signal.windows.gaussian(hi - lo, width)
        flux = np.clip(flux, 0, None)
        peak = flux.max()
        if peak > 0:
            flux /= peak
        return wave, flux

    def load_fits(self, fname):
        """Load the reference spectrum from a FITS file.

        A file with a single HDU is read as a plain spectrum whose wavelength
        axis is built from the CRVAL1 and CDELT1 header keywords. Otherwise
        the "wave" and "spec" columns of the table in the second HDU are used.

        Returns
        -------
        wave, spec : array
            wavelength axis and the normalized spectrum
        """
        with fits.open(fname, memmap=False) as hdu:
            if len(hdu) == 1:
                # Its just the spectrum
                # with the wavelength defined via the header keywords
                header = hdu[0].header
                spec = hdu[0].data.ravel()
                wmin = header["CRVAL1"]
                wdel = header["CDELT1"]
                wave = np.arange(spec.size) * wdel + wmin
            else:
                # Its a binary Table, with two columns for the wavelength and the
                # spectrum
                data = hdu[1].data
                wave = data["wave"]
                spec = data["spec"]
            spec = self._normalize_flux(spec)
        return wave, spec
[docs]
class LineList:
dtype = np.dtype(
(
np.record,
[
(("wlc", "WLC"), ">f8"), # Wavelength (before fit)
(("wll", "WLL"), ">f8"), # Wavelength (after fit)
(("posc", "POSC"), ">f8"), # Pixel Position (before fit)
(("posm", "POSM"), ">f8"), # Pixel Position (after fit)
(("xfirst", "XFIRST"), ">i2"), # first pixel of the line
(("xlast", "XLAST"), ">i2"), # last pixel of the line
(
("approx", "APPROX"),
"O",
), # Not used. Describes the shape used to approximate the line. "G" for Gaussian
(("width", "WIDTH"), ">f8"), # width of the line in pixels
(("height", "HEIGHT"), ">f8"), # relative strength of the line
(("order", "ORDER"), ">i2"), # row index in the wavecal spectrum
(
("bundle", "BUNDLE"),
">i2",
), # bundle id (single-order multi-bundle), -1 if not bundled
("flag", "?"), # flag that tells us if we should use the line or not
],
)
)
def __init__(self, lines=None, obase=None):
if lines is None:
lines = np.array([], dtype=self.dtype)
self.data = lines
self.dtype = self.data.dtype
self.obase = obase # Base spectral order number
def __getitem__(self, key):
return self.data[key]
def __setitem__(self, key, value):
self.data[key] = value
def __len__(self):
return len(self.data)
[docs]
@classmethod
def load(cls, filename):
data = np.load(filename, allow_pickle=True)
linelist = cls(data["cs_lines"])
# Load obase if present
if "obase" in data:
linelist.obase = int(data["obase"])
return linelist
[docs]
def save(self, filename):
if self.obase is not None:
np.savez(filename, cs_lines=self.data, obase=self.obase)
else:
np.savez(filename, cs_lines=self.data)
[docs]
def append(self, linelist):
if isinstance(linelist, LineList):
linelist = linelist.data
self.data = np.append(self.data, linelist)
[docs]
def add_line(self, wave, order, pos, width, height, flag, bundle=-1):
lines = self.from_list(
[wave], [order], [pos], [width], [height], [flag], bundle=[bundle]
)
self.data = np.append(self.data, lines)
[docs]
@classmethod
def from_list(cls, wave, order, pos, width, height, flag, bundle=None):
if bundle is None:
bundle = [-1] * len(wave)
lines = [
(w, w, p, p, p - wi / 2, p + wi / 2, b"G", wi, h, o, b, f)
for w, o, p, wi, h, b, f in zip(
wave, order, pos, width, height, bundle, flag, strict=False
)
]
lines = np.array(lines, dtype=cls.dtype)
return cls(lines)
[docs]
class WavelengthCalibration:
"""
Wavelength Calibration Module
Takes an observed wavelength image and the reference linelist
and returns the wavelength at each pixel
"""
def __init__(
self,
threshold=100,
degree=(6, 6),
iterations=3,
dimensionality="2D",
nstep=0,
correlate_cols=0,
shift_window=0.01,
manual=False,
polarim=False,
lfc_peak_width=3,
closing=5,
atlas_name=None,
atlas_search_dirs=None,
medium="vac",
plot=True,
plot_title=None,
# deprecated alias
element=None,
):
if element is not None and atlas_name is None:
atlas_name = element
#:float: Residual threshold in m/s above which to remove lines
self.threshold = threshold
#:tuple(int, int): polynomial degree of the wavelength fit in (pixel, order) direction
self.degree = degree
if dimensionality == "1D":
self.degree = int(degree)
elif dimensionality == "2D":
self.degree = (int(degree[0]), int(degree[1]))
#:int: Number of iterations in the remove residuals, auto id, loop
self.iterations = iterations
#:{"1D", "2D"}: Whether to use 1d or 2d fit
self.dimensionality = dimensionality
#:bool: Whether to fit for pixel steps (offsets) in the detector
self.nstep = nstep
#:int: How many columns to use in the 2D cross correlation alignment. 0 means all pixels (slow).
self.correlate_cols = correlate_cols
#:float: Fraction if the number of columns to use in the alignment of individual orders. Set to 0 to disable
self.shift_window = shift_window
#:bool: Whether to manually align the reference instead of using cross correlation
self.manual = manual
#:bool: Whether to use polarimetric orders instead of the usual ones. I.e. Each pair of two orders represents the same data. Not Supported yet
self.polarim = polarim
#:int: Whether to plot the results. Set to 2 to plot during all steps.
self.plot = plot
self.plot_title = plot_title
#:str: Name of the line atlas used for calibration
self.atlas_name = atlas_name
#:list: Directories to search for atlas files (before the default)
self.atlas_search_dirs = atlas_search_dirs
#:str: Medium of the detector, vac or air
self.medium = medium
#:int: Laser Frequency Peak width (for scipy.signal.find_peaks)
self.lfc_peak_width = lfc_peak_width
#:int: grey closing range for the input image
self.closing = 5
#:int: Number of orders in the observation
self.ntrace = None
#:int: Number of columns in the observation
self.ncol = None
@property
def step_mode(self):
return self.nstep > 0
@property
def dimensionality(self):
"""{"1D", "2D"}: Whether to use 1D or 2D polynomials for the wavelength solution"""
return self._dimensionality
@dimensionality.setter
def dimensionality(self, value):
accepted_values = ["1D", "2D"]
if value in accepted_values:
self._dimensionality = value
else:
raise ValueError(
f"Value for 'dimensionality' not understood. Expected one of {accepted_values} but got {value} instead"
)
[docs]
def normalize(self, obs, lines):
    """
    Normalize the observation and reference list in each order individually

    Copies the data of the image, but not of the linelist
    (the line heights are rescaled in place).

    Parameters
    ----------
    obs : array of shape (ntrace, ncol)
        observed image
    lines : recarray of shape (nlines,)
        reference linelist

    Returns
    -------
    obs : array of shape (ntrace, ncol)
        normalized image
    lines : recarray of shape (nlines,)
        normalized reference linelist
    """
    # normalize order by order
    # masked_invalid masks NaN/inf values and returns a new masked array
    obs = np.ma.masked_invalid(obs)
    for i in range(len(obs)):
        if self.closing > 0:
            # remember the mask, so it can be restored after the closing
            mask = obs.mask[i].copy()
            obs[i] = grey_closing(obs[i], self.closing)
            obs.mask[i] = mask | np.isnan(obs[i])
        try:
            # subtract the median of the positive values of this order
            obs[i] -= np.ma.median(obs[i][obs[i] > 0])
        except ValueError:
            logger.warning(
                "Could not determine the minimum value in order %i. No positive values found",
                i,
            )
        # scale the order to a maximum of 1
        obs[i] /= np.ma.max(obs[i])

    # Remove negative outliers
    # i.e. mask everything more than 2 standard deviations (per order) below 0
    std = np.std(obs, axis=1)[:, None]
    obs[obs <= -2 * std] = np.ma.masked
    # obs[obs <= 0] = np.ma.masked

    # Normalize lines in each order
    # so that the strongest line of every order has height 1
    for order in np.unique(lines["order"]):
        select = lines["order"] == order
        topheight = np.max(lines[select]["height"])
        lines["height"][select] /= topheight

    return obs, lines
[docs]
def create_image_from_lines(self, lines):
"""
Create a reference image based on a line list
Each line will be approximated by a Gaussian
Space inbetween lines is 0
Parameters
----------
lines : recarray of shape (nlines,)
line data
Returns
-------
img : array of shape (ntrace, ncol)
New reference image
"""
# Use self.ntrace rows so the image matches the observation shape.
# This prevents the cross-correlation alignment from computing a
# spurious order offset when lines don't span all orders.
img = np.zeros((self.ntrace, self.ncol))
for line in lines:
order = int(line["order"])
if order < 0 or order >= self.ntrace:
continue
if line["xlast"] < 0 or line["xfirst"] > self.ncol:
continue
first = int(max(line["xfirst"], 0))
last = int(min(line["xlast"], self.ncol))
img[order, first:last] = line["height"] * signal.windows.gaussian(
last - first, line["width"]
)
return img
[docs]
def align_manual(self, obs, lines):
    """
    Let the user pick the alignment in an AlignmentPlot window

    Parameters
    ----------
    obs : array of shape (ntrace, ncol)
        observed image
    lines : recarray of shape (nlines,)
        reference linelist

    Returns
    -------
    offset : list of two ints
        offset in (order, column), as stored by the AlignmentPlot
    """
    axes = plt.subplots()[1]
    aligner = AlignmentPlot(axes, obs, lines, plot_title=self.plot_title)
    aligner.connect()
    util.show_or_save("wavecal_alignment")
    return aligner.offset
[docs]
def apply_alignment_offset(self, lines, offset, select=None):
"""
Apply an offset to the linelist
Parameters
----------
lines : recarray of shape (nlines,)
reference linelist
offset : tuple(int, int)
offset in (order, column)
select : array of shape(nlines,), optional
Mask that defines which lines the offset applies to
Returns
-------
lines : recarray of shape (nlines,)
linelist with offset applied
"""
if select is None:
select = slice(None)
lines["xfirst"][select] += offset[1]
lines["xlast"][select] += offset[1]
lines["posm"][select] += offset[1]
lines["order"][select] += offset[0]
return lines
[docs]
def align(self, obs, lines):
    """
    Align the observation with the reference spectrum
    Either automatically using cross correlation or manually (visually)

    Whether the alignment is manual is controlled by ``self.manual``.
    The manual alignment window is also opened after the automatic
    alignment if ``self.plot`` is set.

    Parameters
    ----------
    obs : array[nrow, ncol]
        observed wavelength calibration spectrum (e.g. obs=ThoriumArgon)
    lines : struct_array
        reference line data

    Returns
    -------
    lines : struct_array
        reference line data with the offset in order and column applied
    """
    obs = np.ma.filled(obs, 0)

    if not self.manual:
        # make image from lines
        img = self.create_image_from_lines(lines)

        # Crop the image to speed up cross correlation
        if self.correlate_cols != 0:
            # use the central correlate_cols columns only
            _slice = slice(
                (self.ncol - self.correlate_cols) // 2,
                (self.ncol + self.correlate_cols) // 2 + 1,
            )
            ccimg = img[:, _slice]
            ccobs = obs[:, _slice]
        else:
            ccimg, ccobs = img, obs

        # Cross correlate with obs image
        # And determine overall offset
        correlation = signal.correlate2d(ccobs, ccimg, mode="same")
        offset_order, offset_x = np.unravel_index(
            np.argmax(correlation), correlation.shape
        )

        if self.plot >= 2:
            plt.imshow(correlation, aspect="auto")
            plt.vlines(offset_x, -0.5, correlation.shape[0] - 0.5, color="red")
            plt.hlines(offset_order, -0.5, correlation.shape[1] - 0.5, color="red")
            if self.plot_title is not None:
                plt.title(self.plot_title)
            util.show_or_save("wavecal_correlation")

        # convert the peak position into an offset relative to the image centre
        offset_order = offset_order - ccimg.shape[0] / 2 + 1
        offset_x = offset_x - ccimg.shape[1] / 2 + 1
        offset = [int(offset_order), int(offset_x)]

        # apply offset
        lines = self.apply_alignment_offset(lines, offset)

        if self.shift_window != 0:
            # Shift individual orders to fit reference
            # Only allow a small shift here (1%) ?
            img = self.create_image_from_lines(lines)
            for i in range(max(offset[0], 0), min(len(obs), len(img))):
                correlation = signal.correlate(obs[i], img[i], mode="same")
                # only search for the peak in a window around the centre
                width = int(self.ncol * self.shift_window) // 2
                low, high = self.ncol // 2 - width, self.ncol // 2 + width
                offset_x = np.argmax(correlation[low:high]) + low
                offset_x = int(offset_x - self.ncol / 2 + 1)

                select = lines["order"] == i
                lines = self.apply_alignment_offset(lines, (0, offset_x), select)

    if self.plot or self.manual:
        offset = self.align_manual(obs, lines)
        lines = self.apply_alignment_offset(lines, offset)

    # Only the offset that was determined last is logged
    logger.debug(f"Offset order: {offset[0]}, Offset pixel: {offset[1]}")

    return lines
def _fit_single_line(self, obs, center, width, plot=False):
    """Fit a single line with ``util.gaussfit2`` in a window around its position.

    Parameters
    ----------
    obs : array of shape (ncol,)
        spectrum of one order
    center : float
        expected pixel position of the line
    width : float
        width of the line, the window extends 5 widths to either side
    plot : bool, optional
        whether this line may be plotted (only if self.plot >= 2, and
        only for the first 5 lines)

    Returns
    -------
    coef
        fit result of ``util.gaussfit2``; callers in this class read element 1
        as the line position
    """
    reach = width * 5
    start = max(int(center - reach), 0)
    stop = min(int(center + reach), len(obs))

    section = obs[start:stop]
    pixels = np.ma.masked_array(
        np.arange(start, stop, 1), mask=np.ma.getmaskarray(section)
    )
    coef = util.gaussfit2(pixels, section)

    if not (self.plot >= 2 and plot):
        return coef

    # Limit number of line fit plots (tracked in an instance attribute)
    self._line_fit_plot_count = getattr(self, "_line_fit_plot_count", 0) + 1
    shown = self._line_fit_plot_count
    if shown <= 5:
        dense = np.linspace(pixels.min(), pixels.max(), len(pixels) * 100)
        plt.plot(pixels, section, label="Observation")
        plt.plot(dense, util.gaussval2(dense, *coef), label="Fit")
        title = f"Gaussian Fit to spectral line ({self._line_fit_plot_count}/5)"
        if self.plot_title is not None:
            title = f"{self.plot_title}\n{title}"
        plt.title(title)
        plt.xlabel("x [pixel]")
        plt.ylabel("Intensity [a.u.]")
        plt.legend()
        util.show_or_save("wavecal_line_fit")
    elif shown == 6:
        logger.info("Skipping remaining line fit plots (shown 5 of many)")
    return coef
[docs]
def fit_lines(self, obs, lines):
    """
    Determine exact position of each line on the detector based on initial guess

    This fits a Gaussian to each line, and uses the peak position as a new solution.
    Lines outside of the image are left untouched, lines whose fit fails
    are flagged as unusable.

    Parameters
    ----------
    obs : array of shape (ntrace, ncol)
        observed wavelength calibration image
    lines : recarray of shape (nlines,)
        reference line data

    Returns
    -------
    lines : recarray of shape (nlines,)
        Updated line information (posm and flag are changed)
    """
    # For each line fit a gaussian to the observation
    for i, line in tqdm(
        enumerate(lines), total=len(lines), leave=False, desc="Lines"
    ):
        if line["posm"] < 0 or line["posm"] >= obs.shape[1]:
            # Line outside pixel range
            continue
        if line["order"] < 0 or line["order"] >= len(obs):
            # Line outside order range
            continue

        try:
            coef = self._fit_single_line(
                obs[int(line["order"])],
                line["posm"],
                line["width"],
                plot=line["flag"],
            )
            lines[i]["posm"] = coef[1]
        except Exception:
            # Gaussian fit failed, dont use line.
            # `Exception` (instead of a bare except) keeps this best-effort
            # but lets KeyboardInterrupt and SystemExit propagate.
            lines[i]["flag"] = False

    return lines
[docs]
def fit_wavelengths(self, lines, plot=False):
"""Fit a wavelength solution to flagged lines (1D per row or 2D).
Dispatches on ``self.dimensionality``: either a per-row 1D
``np.polyfit`` (rows with too few flagged lines yield NaN coefs and
are skipped) or a single 2D polynomial over (pixel, order).
Parameters
----------
lines : struc_array
line data
plot : bool, optional
whether to plot the solution (default: False)
Returns
-------
coef : array
1D: shape (ntrace, degree+1). 2D: shape (degree_x, degree_y).
"""
if self.step_mode:
return self.build_step_solution(lines, plot=plot)
# Only use flagged data
mask = lines["flag"] # True: use line, False: dont use line
m_wave = lines["wll"][mask]
m_pix = lines["posm"][mask]
m_ord = lines["order"][mask]
if self.dimensionality == "1D":
ntrace = self.ntrace
coef = np.zeros((ntrace, self.degree + 1))
for i in range(ntrace):
select = m_ord == i
n_select = int(np.count_nonzero(select))
if n_select == 0:
coef[i] = np.nan
continue
if n_select < 2:
logger.warning(
"Row %d: only %d flagged line(s); skipping fit", i, n_select
)
coef[i] = np.nan
continue
deg = max(min(self.degree, n_select - 2), 0)
coef[i, -(deg + 1) :] = np.polyfit(
m_pix[select], m_wave[select], deg=deg
)
elif self.dimensionality == "2D":
# 2d polynomial fit with: x = column, y = order, z = wavelength
coef = util.polyfit2d(m_pix, m_ord, m_wave, degree=self.degree, plot=False)
else:
raise ValueError(
f"Parameter 'mode' not understood. Expected '1D' or '2D' but got {self.dimensionality}"
)
if plot or self.plot >= 3: # pragma: no cover
self.plot_residuals(lines, coef, title="Residuals")
return coef
[docs]
def g(self, x, step_coef_pos, step_coef_diff):
try:
bins = step_coef_pos
digits = np.digitize(x, bins) - 1
except ValueError:
return np.inf
cumsum = np.cumsum(step_coef_diff)
x = x + cumsum[digits]
return x
[docs]
def f(self, x, poly_coef, step_coef_pos, step_coef_diff):
xdash = self.g(x, step_coef_pos, step_coef_diff)
if np.all(np.isinf(xdash)):
return np.inf
y = np.polyval(poly_coef, xdash)
return y
[docs]
def build_step_solution(self, lines, plot=False):
    """
    Fit the least squares fit to the wavelength points,
    with additional free parameters for detector gaps, e.g. due to stitching.

    The exact method of the fit depends on the dimensionality.
    Either way we are using the usual polynomial fit for the wavelength, but
    the x points are modified beforehand by shifting them some amount, at specific
    indices. We assume that the stitching effects are distributed evenly and we know how
    many steps we expect (this is set as "nstep").

    The fit alternates 5 times between fitting the polynomial to the
    current shifted positions and fitting the step offsets with the
    polynomial held fixed.

    Parameters
    ----------
    lines : np.recarray
        linedata
    plot : bool, optional
        whether to plot results or not, by default False
        (not used by this method)

    Returns
    -------
    coef
        coefficients of the best fit.
        1D: dict mapping the order to [poly_coef, step_coef].
        2D: tuple (poly_coef, step_coef), with step_coef a dict mapping the
        order to its step coefficients.
        Each step_coef array has the columns (position, offset).
    """
    mask = lines["flag"]  # True: use line, False: dont use line
    m_wave = lines["wll"][mask]
    m_pix = lines["posm"][mask]
    m_ord = lines["order"][mask]

    nstep = self.nstep
    ncol = self.ncol

    if self.dimensionality == "1D":
        coef = {}
        for order in np.unique(m_ord):
            select = m_ord == order
            x = xl = m_pix[select]
            y = m_wave[select]
            step_coef = np.zeros((nstep, 2))
            # evenly spaced step positions, the step offsets start at 0
            step_coef[:, 0] = np.linspace(ncol / (nstep + 1), ncol, nstep + 1)[:-1]

            # poly_coef is looked up when func is called, i.e. it is the
            # polynomial of the current iteration of the loop below
            def func(x, *param):
                return self.f(x, poly_coef, step_coef[:, 0], param)  # noqa: B023

            for _ in range(5):
                poly_coef = np.polyfit(xl, y, self.degree)
                res, _ = curve_fit(func, x, y, p0=step_coef[:, 1], bounds=[-1, 1])
                step_coef[:, 1] = res
                # shifted positions for the polynomial fit of the next iteration
                xl = self.g(x, step_coef[:, 0], step_coef[:, 1])

            coef[order] = [poly_coef, step_coef]
    elif self.dimensionality == "2D":
        unique = np.unique(m_ord)
        ntrace = len(unique)
        shape = (self.degree[0] + 1, self.degree[1] + 1)
        # NOTE(review): the result of this call is discarded
        np.prod(shape)

        step_coef = np.zeros((ntrace, nstep, 2))
        # evenly spaced step positions, identical for all orders
        step_coef[:, :, 0] = np.linspace(ncol / (nstep + 1), ncol, nstep + 1)[:-1]

        def func(x, *param):
            # x holds the pixel positions followed by the orders (see x0 below)
            x, y = x[: len(x) // 2], x[len(x) // 2 :]
            theta = np.asarray(param).reshape((ntrace, nstep))
            xl = np.copy(x)
            for j, i in enumerate(unique):
                xl[y == i] = self.g(x[y == i], step_coef[j, :, 0], theta[j])
            z = polyval2d(xl, y, poly_coef)
            return z

        # TODO: this could use some optimization
        x = np.copy(m_pix)
        x0 = np.concatenate((m_pix, m_ord))
        resid_old = np.inf
        for k in tqdm(range(5)):
            poly_coef = util.polyfit2d(
                x, m_ord, m_wave, degree=self.degree, plot=False
            )

            res, _ = curve_fit(func, x0, m_wave, p0=step_coef[:, :, 1])
            step_coef[:, :, 1] = res.reshape((ntrace, nstep))
            for j, i in enumerate(unique):
                x[m_ord == i] = self.g(
                    m_pix[m_ord == i], step_coef[j][:, 0], step_coef[j][:, 1]
                )

            # sum of squared residuals, only used for the log message
            resid = polyval2d(x, m_ord, poly_coef) - m_wave
            resid = np.sum(resid**2)
            improvement = resid_old - resid
            resid_old = resid
            logger.info(
                "Iteration: %i, Residuals: %.5g, Improvement: %.5g",
                k,
                resid,
                improvement,
            )

        # final polynomial fit with the final step offsets
        poly_coef = util.polyfit2d(x, m_ord, m_wave, degree=self.degree, plot=False)
        step_coef = {i: step_coef[j] for j, i in enumerate(unique)}
        coef = (poly_coef, step_coef)
    else:
        raise ValueError(
            f"Parameter 'dimensionality' not understood. Expected '1D' or '2D' but got {self.dimensionality}"
        )

    return coef
[docs]
def evaluate_step_solution(self, pos, order, solution):
if not np.array_equal(np.shape(pos), np.shape(order)):
raise ValueError("pos and order must have the same shape")
if self.dimensionality == "1D":
result = np.zeros(pos.shape)
for i in np.unique(order):
select = order == i
result[select] = self.f(
pos[select],
solution[i][0],
solution[i][1][:, 0],
solution[i][1][:, 1],
)
elif self.dimensionality == "2D":
poly_coef, step_coef = solution
pos = np.copy(pos)
for i in np.unique(order):
pos[order == i] = self.g(
pos[order == i], step_coef[i][:, 0], step_coef[i][:, 1]
)
result = polyval2d(pos, order, poly_coef)
else:
raise ValueError(
f"Parameter 'mode' not understood, expected '1D' or '2D' but got {self.dimensionality}"
)
return result
[docs]
def evaluate_solution(self, pos, order, solution):
"""
Evaluate the 1d or 2d wavelength solution at the given pixel positions and orders
Parameters
----------
pos : array
pixel position on the detector (i.e. x axis)
order : array
order of each point
solution : array of shape (ntrace, ndegree) or (degree_x, degree_y)
polynomial coefficients. For mode=1D, one set of coefficients per order.
For mode=2D, the first dimension is for the positions and the second for the orders
mode : str, optional
Wether to interpret the solution as 1D or 2D polynomials, by default "1D"
Returns
-------
result: array
Evaluated polynomial
Raises
------
ValueError
If pos and order have different shapes, or mode is of the wrong value
"""
if not np.array_equal(np.shape(pos), np.shape(order)):
raise ValueError("pos and order must have the same shape")
if self.step_mode:
return self.evaluate_step_solution(pos, order, solution)
if self.dimensionality == "1D":
result = np.zeros(pos.shape)
for i in np.unique(order):
select = order == i
result[select] = np.polyval(solution[int(i)], pos[select])
elif self.dimensionality == "2D":
result = np.polynomial.polynomial.polyval2d(pos, order, solution)
else:
raise ValueError(
f"Parameter 'mode' not understood, expected '1D' or '2D' but got {self.dimensionality}"
)
return result
[docs]
def make_wave(self, wave_solution, plot=False):
"""Expand polynomial wavelength solution into full image
Parameters
----------
wave_solution : array of shape(degree,)
polynomial coefficients of wavelength solution
plot : bool, optional
wether to plot the solution (default: False)
Returns
-------
wave_img : array of shape (ntrace, ncol)
wavelength solution for each point in the spectrum
"""
y, x = np.indices((self.ntrace, self.ncol))
wave_img = self.evaluate_solution(x, y, wave_solution)
return wave_img
[docs]
def auto_id(self, obs, wave_img, lines):
    """Automatically identify peaks that are close to known lines

    This is done in two steps. First, if an atlas is available
    (``self.atlas``), strong atlas lines that are not part of the linelist
    yet, but coincide with a peak in the observation, are added to the
    linelist. Second, every unused line of the linelist is flagged as usable
    if a peak is found within ``self.threshold`` (in m/s) of its wavelength.

    Parameters
    ----------
    obs : array of shape (ntrace, ncol)
        observed spectrum
    wave_img : array of shape (ntrace, ncol)
        wavelength solution image
    lines : LineList
        line data, new lines are added with its ``append`` method

    Returns
    -------
    lines : LineList
        line data with new flags
    """
    new_lines = []
    if self.atlas is not None:
        # For each order, find the corresponding section in the Atlas
        # Look for strong lines in the atlas and the spectrum that match in position
        # Add new lines to the linelist
        width_of_atlas_peaks = 3
        for order in range(obs.shape[0]):
            # getmaskarray always returns an array, even if nothing is masked
            mask = ~np.ma.getmaskarray(obs[order])
            index_mask = np.arange(len(mask))[mask]
            data_obs = obs[order, mask]
            wave_obs = wave_img[order, mask]
            if len(wave_obs) < 2:
                # Not enough valid pixels to determine the pixel scale
                continue

            # Largest step between two neighbouring pixels, in m/s
            threshold_of_peak_closeness = (
                np.diff(wave_obs) / wave_obs[:-1] * speed_of_light
            )
            threshold_of_peak_closeness = np.max(threshold_of_peak_closeness)

            wmin, wmax = wave_obs[0], wave_obs[-1]
            imin, imax = np.searchsorted(self.atlas.wave, (wmin, wmax))
            wave_atlas = self.atlas.wave[imin:imax]
            data_atlas = self.atlas.flux[imin:imax]
            if len(data_atlas) == 0:
                continue
            data_atlas = data_atlas / data_atlas.max()

            known = lines[
                (lines["order"] == order)
                & (lines["wll"] > wmin)
                & (lines["wll"] < wmax)
            ]

            peaks_atlas, _ = signal.find_peaks(
                data_atlas, height=0.01, width=width_of_atlas_peaks
            )
            peaks_obs, peak_info_obs = signal.find_peaks(
                data_obs, height=0.01, width=0
            )
            if len(peaks_obs) == 0:
                # No peaks in the observation that could match an atlas line
                continue

            for p in peaks_atlas:
                # Look for an existing line in the vicinity
                wpeak = wave_atlas[p]
                diff = np.abs(known["wll"] - wpeak) / wpeak * speed_of_light
                if np.any(diff < threshold_of_peak_closeness):
                    # Line already in the linelist, ignore
                    continue

                # Look for matching peak in observation
                diff = np.abs(wpeak - wave_obs[peaks_obs]) / wpeak * speed_of_light
                iobs = np.argmin(diff)
                if diff[iobs] < threshold_of_peak_closeness:
                    # Add line to linelist
                    # Location on the detector
                    # Include the masked areas!!!
                    ipeak = index_mask[peaks_obs[iobs]]
                    # relative height of the peak
                    hpeak = data_obs[peaks_obs[iobs]]
                    wipeak = peak_info_obs["widths"][iobs]
                    # wave, order, pos, width, height, flag
                    new_lines.append([wpeak, order, ipeak, wipeak, hpeak, True])

        # Add new lines to the linelist
        if len(new_lines) != 0:
            new_lines = np.array(new_lines).T
            new_lines = LineList.from_list(*new_lines)
            new_lines = self.fit_lines(obs, new_lines)
            lines.append(new_lines)

    # Option 1:
    # Step 1: Loop over unused lines in lines
    # Step 2: find peaks in neighbourhood
    # Step 3: Toggle flag on if close
    counter = 0
    for i, line in enumerate(lines):
        if line["flag"]:
            # Line is already in use
            continue
        if line["order"] < 0 or line["order"] >= self.ntrace:
            # Line outside order range
            continue
        idx = int(line["order"])

        if line["wll"] < wave_img[idx][0] or line["wll"] >= wave_img[idx][-1]:
            # Line outside pixel range
            continue

        wl = line["wll"]
        width = line["width"] * 5
        wave = wave_img[idx]
        order_obs = obs[idx]
        # Find where the line should be
        try:
            idx = np.digitize(wl, wave)
        except ValueError:
            # Wavelength solution is not monotonic
            idx = np.where(wave >= wl)[0][0]

        low = int(idx - width)
        low = max(low, 0)
        high = int(idx + width)
        high = min(high, len(order_obs))

        vec = order_obs[low:high]
        if np.all(np.ma.getmaskarray(vec)):
            continue
        # Find the best fitting peak
        # TODO use gaussian fit?
        peak_idx, _ = signal.find_peaks(vec, height=np.ma.median(vec), width=3)
        if len(peak_idx) == 0:
            continue

        peak_pos = np.copy(peak_idx).astype(float)
        for j, guess in enumerate(peak_idx):
            try:
                coef = self._fit_single_line(vec, guess, line["width"])
                peak_pos[j] = coef[1]
            except Exception:
                # Fit failed, ignore this peak. `Exception` (instead of a
                # bare except) lets KeyboardInterrupt and SystemExit through.
                peak_pos[j] = np.nan

        pos_wave = np.interp(peak_pos, np.arange(high - low), wave[low:high])
        residual = np.abs(wl - pos_wave) / wl * speed_of_light
        if np.all(np.isnan(residual)):
            # None of the peaks could be fitted
            continue
        # Ignore the peaks with a failed fit. A plain argmin would select a
        # NaN entry and thereby discard the valid peaks as well.
        idx = np.nanargmin(residual)
        if residual[idx] < self.threshold:
            counter += 1
            lines["flag"][i] = True
            lines["posm"][i] = low + peak_pos[idx]

    logger.info("AutoID identified %i new lines", counter + len(new_lines))

    return lines
[docs]
def calculate_residual(self, wave_solution, lines):
"""
Calculate all residuals of all given lines
Residual = (Wavelength Solution - Expected Wavelength) / Expected Wavelength * speed of light
Parameters
----------
wave_solution : array of shape (degree_x, degree_y)
polynomial coefficients of the wavelength solution (in numpy format)
lines : recarray of shape (nlines,)
contains the position of the line on the detector (posm), the order (order), and the expected wavelength (wll)
Returns
-------
residual : array of shape (nlines,)
Residual of each line in m/s
"""
x = lines["posm"]
y = lines["order"]
mask = ~lines["flag"]
solution = self.evaluate_solution(x, y, wave_solution)
residual = (solution - lines["wll"]) / lines["wll"] * speed_of_light
residual = np.ma.masked_array(residual, mask=mask)
return residual
[docs]
def reject_outlier(self, residual, lines):
    """
    Flag the line with the largest absolute residual as rejected

    Parameters
    ----------
    residual : array of shape (nlines,)
        residuals of all lines; may be a masked array
    lines : recarray of shape (nlines,)
        line data

    Returns
    -------
    lines : struct_array
        the line data that was passed in, with one more line flagged as bad
    """
    magnitude = np.abs(residual)
    # np.ma.argmax fills masked entries with the lowest possible value,
    # so lines that are already masked are not picked again
    worst = np.ma.argmax(magnitude)
    flags = lines["flag"]
    flags[worst] = False
    return lines
[docs]
def reject_lines(self, lines, plot=False):
    """
    Reject the largest outlier one by one until all residuals are lower than the threshold

    The threshold is taken from self.threshold. After every rejection the
    wavelength solution is fitted again (self.fit_wavelengths) and the
    residuals are recomputed (self.calculate_residual).

    Parameters
    ----------
    lines : recarray of shape (nlines,)
        Line data with pixel position, and expected wavelength
    plot : bool, optional
        Whether to plot the results (default: False).
        The plots are also created if self.plot >= 3.

    Returns
    -------
    lines : recarray of shape (nlines,)
        Line data with updated flags
    """
    wave_solution = self.fit_wavelengths(lines)
    residual = self.calculate_residual(wave_solution, lines)
    nbad = 0
    while np.ma.any(np.abs(residual) > self.threshold):
        lines = self.reject_outlier(residual, lines)
        wave_solution = self.fit_wavelengths(lines)
        residual = self.calculate_residual(wave_solution, lines)
        nbad += 1
    logger.info("Discarding %i lines", nbad)

    if plot or self.plot >= 3:  # pragma: no cover
        # Figure 1: residuals of all lines versus order number
        mask = lines["flag"]
        _, axis = plt.subplots()
        axis.plot(lines["order"][mask], residual[mask], "X", label="Accepted Lines")
        axis.plot(
            lines["order"][~mask], residual[~mask], "D", label="Rejected Lines"
        )
        axis.set_xlabel("Order")
        axis.set_ylabel("Residual [m/s]")
        axis.set_title("Residuals versus order")
        axis.legend()

        # Figure 2: one panel per order, two panels per row
        nrows = max(1, (self.ntrace + 1) // 2)
        ncols = min(2, self.ntrace)
        fig, ax = plt.subplots(nrows=nrows, ncols=ncols, sharex=True, squeeze=False)
        plt.subplots_adjust(hspace=0)
        fig.suptitle("Residuals of each order versus image columns")
        for idx in range(self.ntrace):
            panel = ax[idx // 2, idx % 2]
            order_lines = lines[lines["order"] == idx]
            solution = self.evaluate_solution(
                order_lines["posm"], order_lines["order"], wave_solution
            )
            # Residual in m/s
            order_residual = (
                (solution - order_lines["wll"])
                / order_lines["wll"]
                * speed_of_light
            )
            accepted = order_lines["flag"]
            panel.plot(
                order_lines["posm"][accepted],
                order_residual[accepted],
                "X",
                label="Accepted Lines",
            )
            panel.plot(
                order_lines["posm"][~accepted],
                order_residual[~accepted],
                "D",
                label="Rejected Lines",
            )
            panel.set_ylim(-self.threshold * 1.5, +self.threshold * 1.5)

        # Label every panel of the bottom row. With a single trace the grid
        # has only one column, so the second column must not be indexed.
        for bottom_panel in ax[-1]:
            bottom_panel.set_xlabel("x [pixel]")
        ax[0, 0].legend()
        util.show_or_save("wavecal_reject_lines")
    return lines
[docs]
def plot_results(self, wave_img, obs):
    """
    Plot the final wavelength solution in a figure with two panels

    The upper panel shows the observed spectrum of each order against its
    wavelengths, the lower panel shows the wavelength image itself.

    Parameters
    ----------
    wave_img : array
        wavelength of each pixel, indexed by order first
    obs : array
        observed spectrum, indexed by order first
    """
    heading = "Wavelength solution with Wavelength calibration spectrum\nOrders are in different colours"
    if self.plot_title is not None:
        heading = f"{self.plot_title}\n{heading}"

    plt.figure()

    # Upper panel: one curve per order
    plt.subplot(211)
    plt.title(heading)
    plt.xlabel("Wavelength")
    plt.ylabel("Observed spectrum")
    for trace in range(self.ntrace):
        plt.plot(wave_img[trace], obs[trace], label="Order %i" % trace)

    # Lower panel: the wavelength image with a colorbar
    plt.subplot(212)
    plt.title("2D Wavelength solution")
    image_extent = (0, self.ncol, 0, self.ntrace)
    plt.imshow(wave_img, aspect="auto", origin="lower", extent=image_extent)
    cbar = plt.colorbar()
    cbar.set_label("Wavelength [Å]")
    plt.xlabel("Column")
    plt.ylabel("Order")

    util.show_or_save("wavecal_results")
[docs]
def plot_residuals(self, lines, coef, title="Residuals"):
    """
    Plot the residuals of the lines of each order in a grid with two columns

    Parameters
    ----------
    lines : recarray of shape (nlines,)
        line data; one panel is created per unique value of lines["order"]
    coef : array
        wavelength solution, passed on to calculate_residual
    title : str, optional
        figure title, prefixed with self.plot_title if that is set
        (default: "Residuals")
    """
    plt.figure()
    orders = np.unique(lines["order"])
    n_orders = len(orders)
    if self.plot_title is not None:
        title = f"{self.plot_title}\n{title}"
    plt.suptitle(title)

    n_rows = int(np.ceil(n_orders / 2))
    # Panels (1-based) that keep their x axis labels
    labelled_panels = [n_orders, n_orders - 1]
    for panel, order in enumerate(orders, start=1):
        plt.subplot(n_rows, 2, panel)
        order_lines = lines[lines["order"] == order]
        if len(order_lines) > 0:
            residual = self.calculate_residual(coef, order_lines)
            plt.plot(order_lines["posm"], residual, "rX")
            plt.hlines([0], 0, self.ncol)
        plt.xlim(0, self.ncol)
        plt.ylim(-self.threshold, self.threshold)

        if panel in labelled_panels:
            plt.xlabel("x [Pixel]")
        else:
            plt.xticks([])

        # Panels in the right hand column get no y ticks
        if panel % 2 == 0:
            plt.yticks([])

    plt.subplots_adjust(hspace=0, wspace=0.1)
    util.show_or_save("wavecal_residuals")
def _find_peaks(self, comb):
    """
    Locate the peaks of one order of the comb spectrum

    Peak finding runs twice: the first pass only estimates the typical
    distance between peaks, the second pass uses a quarter of that distance
    as the minimum separation. Each peak is then fitted with
    util.gaussfit3 to refine its position.

    Parameters
    ----------
    comb : array
        comb spectrum of one order

    Returns
    -------
    n : array of int
        running index of each peak that was kept
    new_peaks : array of float
        refined pixel position of each peak that was kept
    """
    flux = comb - np.ma.min(comb)
    min_height = np.ma.median(flux)
    peak_width = self.lfc_peak_width

    # First pass: only used to estimate the peak spacing
    rough, _ = signal.find_peaks(flux, height=min_height, width=peak_width)
    min_distance = np.median(np.diff(rough)) // 4
    # Second pass: for real (disregarding close peaks)
    peaks, _ = signal.find_peaks(
        flux, height=min_height, distance=min_distance, width=peak_width
    )

    # Fit peaks with gaussian to get accurate position
    new_peaks = peaks.astype(float)
    half_window = np.mean(np.diff(peaks)) // 2
    last_pixel = len(flux) - 1
    for j, p in enumerate(peaks):
        window = p + np.arange(-half_window, half_window + 1, 1)
        window = np.clip(window, 0, last_pixel).astype(int)
        try:
            fit = util.gaussfit3(np.arange(len(window)), flux[window])
            # The fit runs in window coordinates, shift back to pixels
            new_peaks[j] = fit[1] + p - half_window
        except RuntimeError:
            new_peaks[j] = p

    # keep peaks within the range
    numbers = np.arange(len(peaks))
    inside = (new_peaks > 0) & (new_peaks < len(flux))
    return numbers[inside], new_peaks[inside]
[docs]
def calculate_AIC(self, lines, wave_solution):
    """
    Calculate the Akaike information criterion of a wavelength solution

    The criterion is computed as AIC = 2 * k + n * log(RSS), where k is the
    number of fitted coefficients plus one, n is the number of lines used,
    and RSS is the sum of the squared relative wavelength residuals
    (solution - wll) / wll. Only lines with flag == True enter n and RSS;
    rejected lines are masked, as in calculate_residual.

    As a side effect, self.logl (log(RSS)), self.aic and self.aicc (the
    small sample corrected value, NaN if n - k - 1 <= 0) are set.

    Parameters
    ----------
    lines : recarray of shape (nlines,)
        line data with the fields posm, order, wll, and flag
    wave_solution : array, dict, or tuple
        wavelength solution as accepted by evaluate_solution. In step mode
        it is a dict of coefficient pairs ("1D"), or a tuple
        (poly_coef, steps_coef) with steps_coef a dict ("2D").

    Returns
    -------
    aic : float
        Akaike information criterion

    Raises
    ------
    ValueError
        If step mode is active and self.dimensionality is neither "1D" nor "2D"
    """
    # Number of free parameters
    if self.step_mode:
        if self.dimensionality == "1D":
            k = 1
            for v in wave_solution.values():
                k += np.size(v[0]) + np.size(v[1])
        elif self.dimensionality == "2D":
            poly_coef, steps_coef = wave_solution
            k = 1 + np.size(poly_coef)
            for v in steps_coef.values():
                k += np.size(v)
        else:
            raise ValueError(
                f"Unknown dimensionality {self.dimensionality!r}, expected '1D' or '2D'"
            )
    else:
        k = np.size(wave_solution) + 1

    # The residuals are kept dimensionless, i.e. without the speed of light
    # factor that calculate_residual applies
    solution = self.evaluate_solution(lines["posm"], lines["order"], wave_solution)
    relative = (solution - lines["wll"]) / lines["wll"]
    # Rejected lines took no part in the fit, so they must not count here
    relative = np.ma.masked_array(relative, mask=~lines["flag"])
    n = relative.count()
    rss = np.ma.sum(relative**2)

    # As per Wikipedia https://en.wikipedia.org/wiki/Akaike_information_criterion
    logl = np.log(rss)
    aic = 2 * k + n * logl
    self.logl = logl
    # Guard against division by zero when too few points
    if n - k - 1 > 0:
        self.aicc = aic + (2 * k**2 + 2 * k) / (n - k - 1)
    else:
        self.aicc = np.nan
    self.aic = aic
    return aic
[docs]
def execute(self, obs, lines):
    """
    Perform the whole wavelength calibration procedure with the current settings

    Parameters
    ----------
    obs : array of shape (ntrace, ncol)
        observed image
    lines : recarray of shape (nlines,)
        reference linelist; wrapped in a LineList if it is not one already

    Returns
    -------
    wave_img : array of shape (ntrace, ncol)
        Wavelength solution for each pixel
    wave_solution
        the final wavelength solution, as returned by fit_wavelengths
    lines : LineList
        the linelist with updated positions and flags

    Raises
    ------
    NotImplementedError
        If polarimitry flag is set
    """
    if self.polarim:
        raise NotImplementedError("polarized orders not implemented yet")

    self.ntrace, self.ncol = obs.shape

    if not isinstance(lines, LineList):
        lines = LineList(lines)

    # Loading the atlas is best effort: without it self.atlas is None
    if self.atlas_name is not None:
        try:
            self.atlas = LineAtlas(
                self.atlas_name, self.medium, search_dirs=self.atlas_search_dirs
            )
        except FileNotFoundError:
            logger.warning("No atlas file found for %s", self.atlas_name)
            self.atlas = None
        except Exception:
            # Still best effort, but report the problem and let
            # KeyboardInterrupt / SystemExit propagate
            logger.warning(
                "Could not load atlas %s", self.atlas_name, exc_info=True
            )
            self.atlas = None
    else:
        self.atlas = None

    obs, lines = self.normalize(obs, lines)
    # Step 1: align obs and reference
    lines = self.align(obs, lines)

    # Keep original positions for reference
    lines["posc"] = np.copy(lines["posm"])

    # Step 2: Locate the lines on the detector, and update the pixel position
    lines = self.fit_lines(obs, lines)

    for i in range(self.iterations):
        logger.info("Wavelength calibration iteration: %i", i)
        # Step 3: Create a wavelength solution on known lines
        wave_solution = self.fit_wavelengths(lines)
        wave_img = self.make_wave(wave_solution)
        # Step 4: Identify lines that fit into the solution
        lines = self.auto_id(obs, wave_img, lines)
        # Step 5: Reject outliers
        lines = self.reject_lines(lines)

    logger.info(
        "Number of lines used for wavelength calibration: %i",
        np.count_nonzero(lines["flag"]),
    )

    # Step 6: build final 2d solution
    wave_solution = self.fit_wavelengths(lines, plot=self.plot)
    wave_img = self.make_wave(wave_solution)

    if self.plot:
        self.plot_results(wave_img, obs)

    aic = self.calculate_AIC(lines, wave_solution)
    logger.info("AIC of wavelength fit: %f", aic)

    return wave_img, wave_solution, lines
[docs]
class WavelengthCalibrationComb(WavelengthCalibration):
    """Wavelength calibration based on the peaks of a laser frequency comb (LFC)"""

    def execute(self, comb, wave, lines=None):
        """
        Derive a wavelength solution from a laser frequency comb spectrum

        The peaks of every order are numbered using the existing wavelength
        solution, the peak numbers of all orders are brought onto one common
        grid, and a single anchor frequency f0 and repetition frequency fr
        are fitted, so that every peak is described by f0 + n * fr.
        The resulting peak wavelengths are then fitted with fit_wavelengths.

        Parameters
        ----------
        comb : array of shape (ntrace, ncol)
            comb spectrum; values <= 0 are masked
        wave : array
            existing wavelength solution, indexed by order first
        lines : recarray, optional
            gas lamp line data; only used for the residual plot (default: None)

        Returns
        -------
        coef
            wavelength solution as returned by fit_wavelengths
        """
        self.ntrace, self.ncol = comb.shape
        comb = np.ma.masked_array(comb, mask=comb <= 0)

        # Per order: peak numbers, frequencies, pixel positions, order index
        n_all, f_all = [], []
        pixel, order = [], []

        for i in range(self.ntrace):
            # Find Peak positions in current order
            _, peaks = self._find_peaks(comb[i])

            # Determine the n-offset of this order, relative to the anchor frequency
            # Use the existing absolute wavelength calibration as reference
            columns = np.arange(len(wave[i]))
            w_old = interp1d(columns, wave[i], kind="cubic")(peaks)
            f_old = speed_of_light / w_old

            # fr: repeating frequency
            # fd: anchor frequency of this order, needs to be shifted to the absolute reference frame
            fr = np.median(np.diff(f_old))
            fd = np.median(f_old % fr)
            n_raw = (f_old - fd) / fr
            n = np.round(n_raw)

            if np.any(np.abs(n_raw - n) > 0.3):
                logger.warning(
                    "Bad peaks detected in the frequency comb in order %i", i
                )

            fr, fd = polyfit(n, f_old, deg=1)

            # The first order is used as the baseline for all other orders
            # The choice is arbitrary and doesn't matter
            if i == 0:
                f0 = fd
                n_offset = 0
            else:
                # shift n to the absolute grid, so that all peaks are given by the same f0
                n_offset = int(round((f0 - fd) / fr))
                n -= n_offset
                fd += n_offset * fr

            n_all.append(np.abs(n))
            f_all.append(f_old)
            pixel.append(peaks)
            order.append(np.full(len(peaks), i))

            logger.debug(
                "LFC Order: %i, f0: %.3f, fr: %.5f, n0: %.2f", i, fd, fr, n_offset
            )

        # Here we postulate that m * lambda = const
        # where m is the peak number
        # this is the result of the grating equation
        # at least const is roughly constant for neighbouring peaks
        w_all = [speed_of_light / f for f in f_all]
        mw_all = [m * w for m, w in zip(n_all, w_all, strict=False)]
        gap = np.median(np.concatenate(mw_all))

        corr = np.zeros(self.ntrace)
        for i in range(self.ntrace):
            shift = np.median(gap / w_all[i] - n_all[i])
            corr[i] = np.round(shift)
            n_all[i] += corr[i]

        logger.debug("LFC order offset correction: %s", corr)

        # Replace the wavelengths by a polynomial fit of m * lambda in each order
        for i in range(self.ntrace):
            mw_coef = polyfit(n_all[i], n_all[i] * w_all[i], deg=5)
            mw = np.polyval(mw_coef, n_all[i])
            w_all[i] = mw / n_all[i]
            f_all[i] = speed_of_light / w_all[i]

        # Merge Data
        n_all = np.concatenate(n_all)
        f_all = np.concatenate(f_all)
        pixel = np.concatenate(pixel)
        order = np.concatenate(order)

        # Fit f0 and fr to all data
        fr, f0 = polyfit(n_all, f_all, deg=1)

        logger.debug("Laser Frequency Comb Anchor Frequency: %.3f 10**10 Hz", f0)
        logger.debug("Laser Frequency Comb Repeating Frequency: %.5f 10**10 Hz", fr)

        # All peaks are then given by f0 + n * fr
        wavelengths = speed_of_light / (f0 + n_all * fr)

        flag = np.full(len(wavelengths), True)
        laser_lines = np.rec.fromarrays(
            (wavelengths, pixel, pixel, order, flag),
            names=("wll", "posm", "posc", "order", "flag"),
        )

        # Use now better resolution to find the new solution
        coef = self.fit_wavelengths(laser_lines)
        new_wave = self.make_wave(coef)

        self.calculate_AIC(laser_lines, coef)

        self.n_lines_good = np.count_nonzero(laser_lines["flag"])
        logger.info(
            f"Laser Frequency Comb solution based on {self.n_lines_good} lines."
        )

        if self.plot:
            # Histogram of the difference to the previous solution
            residual = (wave - new_wave).ravel()
            area = np.percentile(residual, (32, 50, 68))
            area = area[0] - 5 * (area[1] - area[0]), area[0] + 5 * (area[2] - area[1])
            plt.hist(residual, bins=100, range=area)
            title = "ThAr - LFC"
            if self.plot_title is not None:
                title = f"{self.plot_title}\n{title}"
            plt.title(title)
            plt.xlabel(r"$\Delta\lambda$ [Å]")
            plt.ylabel("N")
            util.show_or_save("wavecal_lfc_hist")

            # Residuals of the gas lamp lines and of the comb peaks
            if lines is not None:
                valid = (lines["order"] >= 0) & (lines["order"] < self.ntrace)
                self.plot_residuals(
                    lines[valid],
                    coef,
                    title="GasLamp Line Residuals in the Laser Frequency Comb Solution",
                )
            self.plot_residuals(
                laser_lines,
                coef,
                title="Laser Frequency Comb Peak Residuals in the LFC Solution",
            )

            # Difference between both solutions, one panel per order
            title = "Difference between GasLamp Solution and Laser Frequency Comb solution\nEach plot shows one order"
            if self.plot_title is not None:
                title = f"{self.plot_title}\n{title}"
            plt.suptitle(title)
            n_panels = len(new_wave)
            for i in range(n_panels):
                plt.subplot(n_panels // 4 + 1, 4, i + 1)
                plt.plot(wave[i] - new_wave[i])
            util.show_or_save("wavecal_lfc_diff")

            self.plot_results(new_wave, comb)

        return coef
[docs]
class WavelengthCalibrationInitialize(WavelengthCalibration):
def __init__(
    self,
    degree=2,
    plot=False,
    plot_title="Wavecal Initial",
    resid_delta=2000,
    match_tolerance=1.0,
    iterations=3,
    edge_margin=10,
    width_min=1.0,
    width_max=8.0,
    cutoff=0.01,
    smoothing=0,
    atlas_name="thar",
    atlas_search_dirs=None,
    medium="vac",
    element=None,
):
    """
    Store the settings of the initial line identification

    The parent class is always initialised with dimensionality "1D".
    The remaining arguments are stored as attributes of the same name.
    """
    # `element` is only honoured while atlas_name still has its default value
    if atlas_name == "thar" and element is not None:
        atlas_name = element

    super().__init__(
        degree=degree,
        plot=plot,
        plot_title=plot_title,
        atlas_name=atlas_name,
        atlas_search_dirs=atlas_search_dirs,
        medium=medium,
        dimensionality="1D",
    )

    # Peak detection settings
    self.edge_margin = edge_margin
    self.width_min = width_min
    self.width_max = width_max
    self.cutoff = cutoff
    self.smoothing = smoothing
    # Line matching settings
    self.match_tolerance = match_tolerance
    self.resid_delta = resid_delta
    self.iterations = iterations
[docs]
def get_cutoff(self, spectrum):
    """
    Translate the cutoff setting into a threshold for the given spectrum

    Parameters
    ----------
    spectrum : array
        spectrum; only used when self.cutoff is a percentile

    Returns
    -------
    cutoff : float or None
        None if self.cutoff is 0, self.cutoff itself if it is below 1,
        otherwise the self.cutoff-th percentile of the non-zero values
        of the spectrum (NaN values are ignored)
    """
    level = self.cutoff
    if level == 0:
        return None
    if level < 1:
        return level
    nonzero = spectrum[spectrum != 0]
    return np.nanpercentile(nonzero, level)
[docs]
def normalize(self, spectrum):
    """
    Clean up a spectrum and scale it to a maximum of 1

    Masked values (masked arrays) or NaN values (other arrays) are replaced
    by 0, the median is subtracted, the result is optionally smoothed with a
    Gaussian of width self.smoothing, negative values are set to 0, and the
    spectrum is divided by its maximum if that is positive.

    Parameters
    ----------
    spectrum : array
        input spectrum; it is not modified

    Returns
    -------
    spectrum : array
        normalized copy of the spectrum
    """
    if np.ma.isMaskedArray(spectrum):
        data = np.ma.filled(spectrum, 0)
    else:
        data = np.nan_to_num(spectrum, nan=0, copy=True)

    data = data - np.nanmedian(data)

    sigma = self.smoothing
    if sigma != 0:
        data = gaussian_filter1d(data, sigma)

    negative = data < 0
    data[negative] = 0

    peak = np.max(data)
    if peak > 0:
        data /= peak
    return data
[docs]
def identify_lines_for_order(
    self, spectrum, atlas, wave_range, order, bundle=-1, is_bundle=False
) -> LineList:
    """Identify spectral lines via iterative peak matching (IDL algorithm).

    Detects peaks in the observed spectrum, matches them to atlas lines
    using a polynomial wavelength solution, and iteratively refines the
    fit with outlier rejection.

    Parameters
    ----------
    spectrum : array
        observed spectrum for one order; may be a masked array, in which
        case masked pixels are set to 0 by normalize
    atlas : LineAtlas
        reference line atlas
    wave_range : 2-tuple
        initial wavelength guess (begin, end) in Angstrom
    order : int
        order index
    bundle : int, optional
        bundle index that is stored with every matched line (default: -1)
    is_bundle : bool, optional
        if True, log messages and the plot file name refer to the bundle
        instead of the order (default: False)

    Returns
    -------
    LineList
        matched lines for this order; empty if no lines could be matched
    """
    label = f"Bundle {bundle}" if is_bundle else f"Order {order}"

    # np.asarray() drops the mask of a masked array. That would make the
    # masked branch of normalize() unreachable and let masked pixels enter
    # the peak search as regular data, so masked input is passed on as is.
    if not np.ma.isMaskedArray(spectrum):
        spectrum = np.asarray(spectrum)
    npix = spectrum.shape[0]
    spectrum = self.normalize(spectrum)
    cutoff = self.get_cutoff(spectrum)

    # Detect strict local maxima and minima among the interior pixels
    maxima = np.zeros(npix, dtype=bool)
    minima = np.zeros(npix, dtype=bool)
    center, before, after = spectrum[1:-1], spectrum[:-2], spectrum[2:]
    maxima[1:-1] = (center > before) & (center > after)
    minima[1:-1] = (center < before) & (center < after)
    peak_idx = np.flatnonzero(maxima)
    min_idx = np.flatnonzero(minima)
    if len(peak_idx) == 0:
        logger.warning("%s: no peaks found", label)
        return LineList()

    # Filter: reject peaks near edges
    peak_idx = peak_idx[
        (peak_idx >= self.edge_margin) & (peak_idx <= npix - 1 - self.edge_margin)
    ]
    if len(peak_idx) == 0:
        return LineList()

    # Filter: require at least 3 pixels to the nearest minimum on each side.
    # The bracketing minima are remembered, they delimit the fit window below.
    left_edge = np.zeros(len(peak_idx), dtype=int)
    right_edge = np.zeros(len(peak_idx), dtype=int)
    good = np.zeros(len(peak_idx), dtype=bool)
    for i, pk in enumerate(peak_idx):
        left_mins = min_idx[min_idx < pk]
        right_mins = min_idx[min_idx > pk]
        if len(left_mins) == 0 or len(right_mins) == 0:
            continue
        if (pk - left_mins[-1]) < 3 or (right_mins[0] - pk) < 3:
            continue
        left_edge[i] = left_mins[-1]
        right_edge[i] = right_mins[0]
        good[i] = True
    peak_idx = peak_idx[good]
    left_edge = left_edge[good]
    right_edge = right_edge[good]
    if len(peak_idx) == 0:
        return LineList()

    # Filter by cutoff threshold
    if cutoff is not None:
        bright = spectrum[peak_idx] >= cutoff
        peak_idx = peak_idx[bright]
        left_edge = left_edge[bright]
        right_edge = right_edge[bright]
        if len(peak_idx) == 0:
            return LineList()

    # Gaussian fit each peak to get sub-pixel position, width, height
    posm = np.zeros(len(peak_idx))
    widths = np.zeros(len(peak_idx))
    heights = np.zeros(len(peak_idx))
    fit_ok = np.zeros(len(peak_idx), dtype=bool)
    for i, (left, right) in enumerate(zip(left_edge, right_edge)):
        seg_x = np.arange(left, right + 1, dtype=float)
        seg_y = spectrum[left : right + 1]
        try:
            amp, mu, sig2, _offset = util.gaussfit3(seg_x, seg_y)
        except (RuntimeError, ValueError):
            continue
        fwhm = 2.355 * np.sqrt(np.abs(sig2))
        # The chained comparison also rejects a NaN width
        if not (self.width_min <= fwhm <= self.width_max):
            continue
        if not np.isfinite(mu):
            continue
        posm[i] = mu
        widths[i] = fwhm
        heights[i] = amp
        fit_ok[i] = True
    posm = posm[fit_ok]
    widths = widths[fit_ok]
    heights = heights[fit_ok]
    if len(posm) == 0:
        logger.warning("%s: no valid peaks after Gaussian fitting", label)
        return LineList()

    # Atlas lines within the expected range (with margin)
    wmin = min(wave_range[0], wave_range[1])
    wmax = max(wave_range[0], wave_range[1])
    margin = 0.05 * (wmax - wmin)
    atlas_waves = atlas.linelist["wave"]
    atlas_mask = (atlas_waves > wmin - margin) & (atlas_waves < wmax + margin)
    atlas_sub = atlas_waves[atlas_mask]
    if len(atlas_sub) == 0:
        logger.warning("%s: no atlas lines in range %.1f-%.1f", label, wmin, wmax)
        return LineList()

    # Step 1: linear wavelength assignment
    wlc = wave_range[0] + (wave_range[1] - wave_range[0]) * posm / npix

    # Step 2: offset voting - match each atlas line to nearest peak,
    # histogram the wavelength offsets to find the true shift
    offsets = []
    for aw in atlas_sub:
        dw = np.abs(wlc - aw)
        best = np.argmin(dw)
        if dw[best] < self.match_tolerance:
            offsets.append(aw - wlc[best])
    if len(offsets) < self.degree + 1:
        logger.warning("%s: only %d coarse matches", label, len(offsets))
        return LineList()
    offsets = np.array(offsets)
    n_bins = max(10, len(offsets) // 2)
    hist, edges = np.histogram(offsets, bins=n_bins)
    best_bin = np.argmax(hist)
    mode_offset = (edges[best_bin] + edges[best_bin + 1]) / 2
    bin_width = edges[1] - edges[0]
    near_mode = offsets[np.abs(offsets - mode_offset) < bin_width * 3]
    if len(near_mode) >= 3:
        wave_offset = np.median(near_mode)
    else:
        wave_offset = mode_offset
    # Apply offset correction
    wlc += wave_offset

    # Step 3: iterative match-fit-reject using corrected wavelengths
    # After voting, the corrected wlc is accurate to ~bin_width,
    # so use a tight tolerance for individual matching (like IDL's 0.02A)
    tight_tol = max(0.02, bin_width * 2)
    best_peak = np.array([])
    best_atlas = np.array([])
    best_width = np.array([])
    best_height = np.array([])
    coef = None
    for iteration in range(self.iterations):
        if iteration > 0:
            # coef is always set here: the first iteration either fits
            # a solution or leaves the loop
            wlc = np.polyval(coef, posm)
        tol = (
            tight_tol
            if iteration == 0
            else (self.resid_delta / speed_of_light * np.median(np.abs(wlc)))
        )

        # For each atlas line, find the nearest detected peak.
        # Every peak is assigned to at most one atlas line.
        matched_peak_l = []
        matched_atlas_l = []
        matched_width_l = []
        matched_height_l = []
        used = set()
        for aw in atlas_sub:
            dw = np.abs(wlc - aw)
            best = int(np.argmin(dw))
            if dw[best] < tol and best not in used:
                matched_peak_l.append(posm[best])
                matched_atlas_l.append(aw)
                matched_width_l.append(widths[best])
                matched_height_l.append(heights[best])
                used.add(best)
        if len(matched_peak_l) < self.degree + 1:
            # Not enough lines for a fit, keep the previous result
            break
        matched_peak = np.array(matched_peak_l)
        matched_atlas = np.array(matched_atlas_l)
        matched_width = np.array(matched_width_l)
        matched_height = np.array(matched_height_l)

        # Iterative sigma-clipping: fit, reject worst outliers, refit
        for _clip in range(5):
            coef = polyfit(matched_peak, matched_atlas, self.degree)
            fitted_wave = np.polyval(coef, matched_peak)
            resid_vel = (
                np.abs(fitted_wave - matched_atlas) / matched_atlas * speed_of_light
            )
            med = np.median(resid_vel)
            mad = np.median(np.abs(resid_vel - med)) * 1.4826
            clip_thresh = max(med + 3 * mad, self.resid_delta)
            keep = resid_vel < clip_thresh
            n_keep = np.sum(keep)
            if n_keep < self.degree + 1 or n_keep == len(matched_peak):
                break
            matched_peak = matched_peak[keep]
            matched_atlas = matched_atlas[keep]
            matched_width = matched_width[keep]
            matched_height = matched_height[keep]
        coef = polyfit(matched_peak, matched_atlas, self.degree)

        # Save best result so far
        best_peak = matched_peak
        best_atlas = matched_atlas
        best_width = matched_width
        best_height = matched_height

    matched_peak = best_peak
    matched_atlas = best_atlas
    matched_width = best_width
    matched_height = best_height

    if len(matched_peak) == 0:
        logger.warning("%s: no lines matched", label)
        return LineList()

    linelist = LineList()
    for m_wave, m_pos, m_width, m_height in zip(
        matched_atlas, matched_peak, matched_width, matched_height
    ):
        linelist.add_line(
            m_wave,
            order,
            m_pos,
            m_width,
            m_height,
            True,
            bundle=bundle,
        )

    # Cap per-row plots so single-order multi-bundle runs don't
    # spawn dozens of figures (mirrors _fit_single_line's "5 of many").
    if not hasattr(self, "_init_plot_count"):
        self._init_plot_count = 0
    if self.plot:
        self._init_plot_count += 1
        if self._init_plot_count == 6:
            logger.info("Skipping remaining wavecal_init plots (shown 5 of many)")
    if self.plot and self._init_plot_count <= 5:
        save_tag = (
            f"wavecal_init_bundle_{bundle}"
            if is_bundle
            else f"wavecal_init_order_{order}"
        )
        wave = np.polyval(coef, np.arange(npix))
        atlas_flux = np.interp(wave, atlas.wave, atlas.flux)
        atlas_flux /= np.max(atlas_flux)
        plt.plot(wave, spectrum, label="observed")
        plt.plot(wave, atlas_flux, alpha=0.5, label="atlas")
        plt.plot(
            matched_atlas,
            matched_height,
            "rx",
            markersize=4,
            label=f"{len(matched_peak)} matches",
        )
        title = f"{label} ({self._init_plot_count}/5)"
        if self.plot_title:
            title = f"{self.plot_title}\n{title}"
        plt.title(title)
        plt.xlabel("Wavelength [A]")
        plt.legend(fontsize="small")
        util.show_or_save(save_tag)

    rms = np.std(
        (np.polyval(coef, matched_peak) - matched_atlas)
        / matched_atlas
        * speed_of_light
    )
    logger.info("%s: matched %d lines, rms=%.1f m/s", label, len(matched_peak), rms)
    return linelist
[docs]
def execute(self, spectrum, wave_range, single_order=False) -> LineList:
    """
    Identify lines in every row of the spectrum and collect them in one linelist

    Parameters
    ----------
    spectrum : array
        observed spectra, one row per order (or per bundle)
    wave_range : sequence
        initial wavelength range of each row. A single entry is reused
        for all rows if the spectrum has more than one row.
    single_order : bool, optional
        if True, the rows are treated as bundles and the row index is
        passed on as the bundle index, otherwise the bundle index is -1
        (default: False)

    Returns
    -------
    LineList
        the lines matched in all rows
    """
    atlas = LineAtlas(
        self.atlas_name, self.medium, search_dirs=self.atlas_search_dirs
    )
    linelist = LineList()
    n_rows = spectrum.shape[0]

    # Single-order multi-fiber instruments (e.g. MOSAIC) give one wave_range
    # entry but extract many bundles -- reuse the same range for each row.
    if len(wave_range) == 1 and n_rows > 1:
        logger.info(
            "wavelength_range has 1 entry but spectrum has %d rows; broadcasting",
            n_rows,
        )
        wave_range = [wave_range[0]] * n_rows

    for row in range(n_rows):
        bundle = row if single_order else -1
        row_lines = self.identify_lines_for_order(
            spectrum[row],
            atlas,
            wave_range[row],
            row,
            bundle=bundle,
            is_bundle=single_order,
        )
        linelist.append(row_lines)
    return linelist