Source code for pyreduce.extract
"""Module for extracting data from observations
Authors
-------
Version
-------
License
-------
"""
import logging
import os
import time
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.widgets import Button
from scipy.interpolate import interp1d
from tqdm import tqdm
from . import util
from .spectra import Spectrum
from .trace_model import Trace
from .util import make_index
logger = logging.getLogger(__name__)
# Backend selection: set PYREDUCE_USE_CHARSLIT=1 to use charslit.
# The environment is consulted on every call (nothing is cached at import time),
# so changing the variable inside a running process takes effect immediately.
_charslit_mod = None


def _env_flag_is_one(name, default):
    """Return True if environment variable ``name`` (or ``default`` when unset) is "1"."""
    value = os.environ.get(name, default)
    return value == "1"


def _use_charslit():
    """Read the PYREDUCE_USE_CHARSLIT flag (unset means "0", i.e. disabled)."""
    return _env_flag_is_one("PYREDUCE_USE_CHARSLIT", "0")


def _use_deltas():
    """Read the PYREDUCE_USE_DELTAS flag (unset means "1", i.e. enabled)."""
    return _env_flag_is_one("PYREDUCE_USE_DELTAS", "1")


def _get_charslit():
    """Import the optional ``charslit`` module on first use and cache it."""
    global _charslit_mod
    if _charslit_mod is not None:
        return _charslit_mod
    import charslit

    _charslit_mod = charslit
    return _charslit_mod
from . import cwrappers
def _slitdec_charslit(
    img,
    ycen,
    slitcurve,
    slitdeltas,
    lambda_sp,
    lambda_sf,
    osample,
    yrange,
    maxiter,
    gain,
    reject_threshold,
    preset_slitfunc,
):
    """Call charslit.slitdec and convert results to the expected format.

    Non-finite pixels are replaced by 0 and flagged as bad in the mask that is
    passed to charslit.

    Parameters
    ----------
    img : array[nrows, ncols]
        Input image swath (may be masked array)
    ycen : array[ncols]
        Trace center positions (fractional)
    slitcurve : array[ncols, 6]
        Polynomial coefficients for slit curvature (c0..c5)
    slitdeltas : array[nrows] or None
        Per-row residual offsets (zeros if None)
    lambda_sp : float
        Spectrum smoothing parameter
    lambda_sf : float
        Slit function smoothing parameter
    osample : int
        Oversampling factor
    yrange : tuple[int, int]
        Extraction range (below, above). Not used by this backend; accepted for
        signature compatibility with ``_slitdec_cffi``.
    maxiter : int
        Maximum iterations
    gain : float
        Detector gain; pixel uncertainties are ``sqrt(|data| * gain)``, floored at 1.
    reject_threshold : float
        Outlier rejection threshold in sigma units (passed as kappa to charslit)
    preset_slitfunc : array or None
        Preset slit function (not supported by charslit yet, ignored)

    Returns
    -------
    sp : array[ncols]
        Extracted spectrum
    sl : array[nslitf]
        Slit function
    model : array[nrows, ncols]
        Model image
    unc : array[ncols]
        Spectrum uncertainties
    mask : array[nrows, ncols]
        Output mask (True = bad)
    info : array[5]
        [success, chi2, status, niter, delta_x]. If charslit does not return a
        usable ``info`` array, one is built from its ``return_code``:
        success = (return_code == 0), status = return_code, the rest 0.
    """
    nrows, _ = img.shape

    # Float64 copy of the data. Flag non-finite pixels *before* zeroing them,
    # otherwise they would look like valid (zero-valued) pixels to charslit.
    data = np.ma.getdata(img).astype(np.float64)
    nonfinite = ~np.isfinite(data)
    mask_in = np.ma.getmaskarray(img) | nonfinite
    data[nonfinite] = 0

    # Pixel uncertainties (shot noise), floored at 1
    pix_unc = np.sqrt(np.abs(data) * gain)
    pix_unc[pix_unc < 1] = 1

    # Convert mask: numpy (True=bad) -> charslit (0=bad, 1=good)
    mask_c = np.ascontiguousarray(~mask_in, dtype=np.uint8)

    data = np.ascontiguousarray(data)
    pix_unc = np.ascontiguousarray(pix_unc, dtype=np.float64)
    # charslit expects full ycen and does the integer/fractional split internally.
    # Always pass copies so the C code never sees the caller's arrays.
    ycen_c = np.array(ycen, dtype=np.float64, order="C")
    # charslit expects slitcurve of shape (ncols, 6) - coeffs c0..c5
    slitcurve_c = np.array(slitcurve, dtype=np.float64, order="C")
    if slitdeltas is None:
        slitdeltas = np.zeros(nrows, dtype=np.float64)
    slitdeltas_c = np.array(slitdeltas, dtype=np.float64, order="C")

    # Note: preset_slitfunc is not currently supported by charslit
    if preset_slitfunc is not None:
        logger.debug("preset_slitfunc is not yet supported by charslit, ignoring")

    result = _get_charslit().slitdec(
        data,
        pix_unc,
        mask_c,
        ycen_c,
        slitcurve_c,
        slitdeltas_c,
        osample=osample,
        lambda_sP=float(lambda_sp),
        lambda_sL=float(lambda_sf),
        maxiter=maxiter,
        kappa=float(reject_threshold),
    )

    sp = result["spectrum"]
    sl = result["slitfunction"]
    model = result["model"]
    unc = result["uncertainty"]
    return_code = result.get("return_code", 0)

    # Convert mask back: charslit (0=bad) -> numpy (True=bad)
    mask_out = np.asarray(result.get("mask", mask_c)) == 0

    # charslit returns info as [success, cost, status, iter, delta_x]. Only when
    # it is missing or malformed, derive a summary from the return code (the
    # default must be None, otherwise this fallback could never trigger).
    info = result.get("info")
    if not (isinstance(info, np.ndarray) and info.ndim == 1 and info.size >= 5):
        info = np.array([float(return_code == 0), 0.0, float(return_code), 0.0, 0.0])
    return sp, sl, model, unc, mask_out, info
def _slitdec_cffi(
    img,
    ycen,
    curvature,
    lambda_sp,
    lambda_sf,
    osample,
    yrange,
    maxiter,
    gain,
    reject_threshold,
    preset_slitfunc,
):
    """Legacy extraction backend: run ``cwrappers.slitfunc_curved`` (CFFI C extension).

    The result is the same 6-tuple ``(sp, sl, model, unc, mask, info)`` that the
    charslit backend returns. Only curvature columns 1 and 2 (p1, p2) are passed
    on; column 0 and any higher-order columns are not used. Columns that are
    missing, or all of them when ``curvature`` is None, are replaced by zeros.
    """

    def _coefficient(order):
        # Column `order` of the curvature array, or a fresh zero vector if absent
        if curvature is None:
            return np.zeros(len(ycen))
        if curvature.shape[1] <= order:
            return np.zeros(curvature.shape[0])
        return curvature[:, order]

    p1 = _coefficient(1)
    p2 = _coefficient(2)

    outputs = cwrappers.slitfunc_curved(
        img,
        ycen,
        p1,
        p2,
        lambda_sp,
        lambda_sf,
        osample,
        yrange,
        maxiter=maxiter,
        gain=gain,
        reject_threshold=reject_threshold,
        preset_slitfunc=preset_slitfunc,
    )
    sp, sl, model, unc, mask, info = outputs
    return sp, sl, model, unc, mask, info
def _ensure_slitcurve(curvature, ncols, n_coeffs=6):
    """Ensure curvature is in the right format for charslit.

    Parameters
    ----------
    curvature : array[ncols, n] or None
        Curvature coefficients for this trace/swath, or None for vertical extraction.
    ncols : int
        Number of columns; the curvature must have exactly this many rows.
    n_coeffs : int
        Number of coefficients (default 6 for charslit).

    Returns
    -------
    slitcurve : array[ncols, n_coeffs]
        New C-contiguous float64 array: the first ``n_coeffs`` coefficients,
        zero-padded if fewer were given. Never a view of ``curvature``.

    Raises
    ------
    ValueError
        If ``curvature`` is not 2-D or does not have ``ncols`` rows.
    """
    if curvature is None:
        return np.zeros((ncols, n_coeffs), dtype=np.float64)
    curvature = np.asarray(curvature, dtype=np.float64)
    if curvature.ndim != 2 or curvature.shape[0] != ncols:
        raise ValueError(
            f"curvature must have shape ({ncols}, n), got {curvature.shape}"
        )
    # Truncate or zero-pad into a fresh array
    result = np.zeros((ncols, n_coeffs), dtype=np.float64)
    ncopy = min(curvature.shape[1], n_coeffs)
    result[:, :ncopy] = curvature[:, :ncopy]
    return result
[docs]
class ProgressPlot:  # pragma: no cover
    """Interactive matplotlib window that shows the extraction swath by swath.

    Each call to :meth:`plot` draws one swath: observation, model (with input
    and newly rejected masks) and residual images, the slit-function panel and
    the spectrum panel. Slower/Faster/Pause/Step buttons control playback.
    Unless ``save=False``, every plotted swath is also written as an ``.npz``
    file to ``$REDUCE_DATA/debug`` (default ``~/REDUCE_DATA/debug``).
    """

    def __init__(self, nrow, ncol, nslitf, title=None):
        self.nrow = nrow
        self.ncol = ncol
        self.nslitf = nslitf

        # Setup debug output directory for saving swath data
        from pathlib import Path

        reduce_data = os.environ.get(
            "REDUCE_DATA", os.path.expanduser("~/REDUCE_DATA")
        )
        self.save_dir = Path(reduce_data) / "debug"
        self.save_dir.mkdir(parents=True, exist_ok=True)

        # Minimum time in seconds between two frames (changed by the buttons)
        self.min_frame_time = float(
            os.environ.get("PYREDUCE_PLOT_ANIMATION_SPEED", 0.3)
        )
        self.last_frame_time = None

        plt.ion()
        plt.rcParams["figure.raise_window"] = False
        self.fig = plt.figure(figsize=(12, 8))
        gs = self.fig.add_gridspec(
            4,
            5,
            height_ratios=[1, 1, 1, 1.2],
            width_ratios=[0.03, 1, 1, 1, 0.8],
            hspace=0.05,
            wspace=0.05,
        )

        # Colorbar axes (left column)
        self.ax_cbar_img = self.fig.add_subplot(gs[0:2, 0])
        self.ax_cbar_resid = self.fig.add_subplot(gs[2, 0])

        # Image panels (stacked vertically with no gaps, no tick labels)
        self.ax_obs = self.fig.add_subplot(gs[0, 1:4])
        self.ax_obs.set_axis_off()
        self.ax_model = self.fig.add_subplot(
            gs[1, 1:4], sharex=self.ax_obs, sharey=self.ax_obs
        )
        self.ax_model.set_axis_off()
        self.ax_resid = self.fig.add_subplot(
            gs[2, 1:4], sharex=self.ax_obs, sharey=self.ax_obs
        )
        self.ax_resid.set_axis_off()

        # Slit function panel (rightmost column, top 3 rows, rotated axes)
        self.ax_slit = self.fig.add_subplot(gs[0:3, 4])
        self.ax_slit.set_title("Slit")
        self.ax_slit.set_xlabel("contribution")
        self.ax_slit.set_ylim((0, nrow))
        self.ax_slit.yaxis.set_label_position("right")
        self.ax_slit.yaxis.tick_right()

        # Spectrum panel (full bottom row)
        self.ax_spec = self.fig.add_subplot(gs[3, 1:])
        self.ax_spec.set_xlim((0, ncol))

        self.title = title
        if title is not None:
            self.fig.suptitle(title)

        # Create image plots
        img = np.ones((nrow, ncol))
        self.im_obs = self.ax_obs.imshow(img, aspect="auto", origin="lower")
        self.im_model = self.ax_model.imshow(img, aspect="auto", origin="lower")
        self.im_resid = self.ax_resid.imshow(
            np.zeros((nrow, ncol)), aspect="auto", origin="lower", cmap="bwr"
        )

        # Colorbars in dedicated axes (ticks/labels on left)
        self.cbar_img = self.fig.colorbar(
            self.im_obs, cax=self.ax_cbar_img, ticklocation="left"
        )
        self.cbar_resid = self.fig.colorbar(
            self.im_resid, cax=self.ax_cbar_resid, ticklocation="left"
        )

        # Spectrum plot elements (rejected first as background, then good points)
        (self.rejected_spec,) = self.ax_spec.plot([], [], ".r", ms=2, alpha=0.2)
        (self.good_spec,) = self.ax_spec.plot([], [], ".g", ms=2, alpha=0.2)
        (self.line_spec,) = self.ax_spec.plot([], "-k")

        # Slit function plot elements (rejected first as background, then good points)
        (self.rejected_slit,) = self.ax_slit.plot([], [], ".r", ms=2, alpha=0.2)
        (self.good_slit,) = self.ax_slit.plot([], [], ".g", ms=2, alpha=0.2)
        (self.line_slit,) = self.ax_slit.plot([], [], "-k", lw=2)

        # Playback state, driven by the buttons and the window close event
        self.paused = False
        self.advance_one = False
        self.closed = False
        self.fig.canvas.mpl_connect("close_event", self._on_close)

        ax_slower = self.fig.add_axes([0.30, 0.02, 0.08, 0.04])
        ax_faster = self.fig.add_axes([0.39, 0.02, 0.08, 0.04])
        ax_pause = self.fig.add_axes([0.48, 0.02, 0.08, 0.04])
        ax_step = self.fig.add_axes([0.57, 0.02, 0.08, 0.04])
        self.btn_slower = Button(ax_slower, "Slower")
        self.btn_faster = Button(ax_faster, "Faster")
        self.btn_pause = Button(ax_pause, "Pause")
        self.btn_step = Button(ax_step, "Step")
        self.btn_slower.on_clicked(self._slower)
        self.btn_faster.on_clicked(self._faster)
        self.btn_pause.on_clicked(self._toggle_pause)
        self.btn_step.on_clicked(self._step)

        self.fig.subplots_adjust(bottom=0.12, top=0.92, left=0.05, right=0.92)
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()

    def _slower(self, event=None):
        """Increase the minimum frame time by 1.5x (capped at 2 s)."""
        self.min_frame_time = min(2.0, self.min_frame_time * 1.5)

    def _faster(self, event=None):
        """Decrease the minimum frame time by 1.5x (floored at 0.01 s)."""
        self.min_frame_time = max(0.01, self.min_frame_time / 1.5)

    def _toggle_pause(self, event=None):
        """Toggle pause and update the button label accordingly."""
        self.paused = not self.paused
        self.btn_pause.label.set_text("Resume" if self.paused else "Pause")
        self.fig.canvas.draw()

    def _step(self, event=None):
        """While paused, allow exactly one more frame to be drawn."""
        if self.paused:
            self.advance_one = True

    def _on_close(self, event=None):
        """Window closed: stop plotting and never block in wait_if_paused."""
        self.closed = True
        self.paused = False

    def wait_if_paused(self):
        """Block (processing GUI events) while paused, until Step or close."""
        while self.paused and not self.advance_one:
            if self.closed:
                break
            self.fig.canvas.flush_events()
            time.sleep(0.05)
        self.advance_one = False

    def plot(
        self,
        img,
        spec,
        slitf,
        model,
        ycen,
        input_mask,
        output_mask,
        trace_idx=0,
        left=0,
        right=0,
        unc=None,
        info=None,
        swath_idx=0,
        save=True,
        slitcurve=None,
        slitdeltas=None,
    ):
        """Draw one swath, optionally save it to the debug directory, then pace/pause.

        ``input_mask``/``output_mask`` are boolean masks (True = bad) of the
        swath shape. ``info[3]``, if given, is shown as the iteration count.
        """
        if self.closed:
            return

        # Save swath data to debug directory
        if save:
            outfile = self.save_dir / f"swath_trace{trace_idx}_swath{swath_idx}.npz"
            np.savez(
                outfile,
                swath_img=img,
                ycen=ycen,
                spec=spec,
                slitf=slitf,
                model=model,
                unc=unc,
                input_mask=input_mask,
                output_mask=output_mask,
                info=info,
                slitcurve=slitcurve,
                slitdeltas=slitdeltas,
            )

        img = np.copy(img)
        spec = np.copy(spec)
        slitf = np.copy(slitf)
        ycen = np.copy(ycen)
        # Accept scalar masks (e.g. np.ma.nomask) as well as full arrays
        input_mask = np.broadcast_to(np.asarray(input_mask, dtype=bool), img.shape)
        output_mask = np.broadcast_to(np.asarray(output_mask, dtype=bool), img.shape)

        ny, nspec = img.shape
        x_spec, y_spec = self.get_spec(img, spec, slitf, ycen, slitcurve, slitdeltas)
        x_slit, y_slit = self.get_slitf(img, spec, slitf, ycen, slitcurve, slitdeltas)
        # Slit-function grid used by get_spec, shifted by +0.5 so it lines up
        # with the display positions returned by get_slitf
        old = np.linspace(-1, ny, len(slitf)) + 0.5

        # Separate rejected (output_mask=True) and good (output_mask=False) pixels
        rejected = output_mask.ravel()
        good = ~rejected

        # Update image data (nan-aware so NaN pixels don't break the color scale)
        vmin, vmax = np.nanpercentile(img, [5, 95])
        self.im_obs.set_data(img)
        self.im_obs.set_clim(vmin, vmax)

        # Show masks on model panel: input mask (white), newly rejected (red).
        # Masks are drawn first (underneath), model on top as masked array.
        new_bad = output_mask & ~input_mask
        # Create RGBA arrays (transparent where no mask)
        input_rgba = np.zeros((*img.shape, 4), dtype=np.float32)
        input_rgba[input_mask, :] = [1, 1, 1, 1]  # white
        new_rgba = np.zeros((*img.shape, 4), dtype=np.float32)
        new_rgba[new_bad, :] = [1, 0, 0, 1]  # red

        # Match extent of underlying model image
        extent = self.im_model.get_extent()
        if hasattr(self, "_mask_im_new"):
            self._mask_im_new.set_data(new_rgba)
            self._mask_im_input.set_data(input_rgba)
        else:
            # Draw masks first (lower zorder), then model on top
            self._mask_im_new = self.ax_model.imshow(
                new_rgba,
                aspect="auto",
                origin="lower",
                interpolation="nearest",
                extent=extent,
                zorder=1,
            )
            self._mask_im_input = self.ax_model.imshow(
                input_rgba,
                aspect="auto",
                origin="lower",
                interpolation="nearest",
                extent=extent,
                zorder=2,
            )
            # Move model image to top so cursor reads model values
            self.im_model.set_zorder(3)

        # Model as masked array: transparent where either mask is set
        model_masked = np.ma.array(model, mask=input_mask | output_mask)
        self.im_model.set_data(model_masked)
        self.im_model.set_clim(vmin, vmax)

        resid = img - model
        rlim = np.nanpercentile(np.abs(resid), 99)
        if not np.isfinite(rlim) or rlim <= 0:
            rlim = 1.0  # avoid a degenerate (or NaN) symmetric color range
        self.im_resid.set_data(resid)
        self.im_resid.set_clim(-rlim, rlim)

        # Update spectrum panel
        self.rejected_spec.set_xdata(x_spec[rejected])
        self.rejected_spec.set_ydata(y_spec[rejected])
        self.good_spec.set_xdata(x_spec[good])
        self.good_spec.set_ydata(y_spec[good])
        self.line_spec.set_xdata(np.arange(len(spec)))
        self.line_spec.set_ydata(spec)

        # Update slit function panel (rotated: contribution on x, y-pixel on y)
        self.rejected_slit.set_xdata(y_slit[rejected])
        self.rejected_slit.set_ydata(x_slit[rejected])
        self.good_slit.set_xdata(y_slit[good])
        self.good_slit.set_ydata(x_slit[good])
        self.line_slit.set_xdata(slitf)
        self.line_slit.set_ydata(old)

        self.ax_spec.set_xlim((0, nspec - 1))
        # Ignore the swath edges when scaling the spectrum axis
        spec_middle = spec[5:-5] if len(spec) > 10 else spec
        limit = np.nanmax(spec_middle) * 1.1 if len(spec_middle) > 0 else 1.0
        if np.isfinite(limit) and limit > 0:
            self.ax_spec.set_ylim((0, limit))

        self.ax_slit.set_ylim((0, ny))
        limit = np.nanmax(slitf) * 1.1
        if np.isfinite(limit) and limit > 0:
            self.ax_slit.set_xlim((0, limit))

        niter = int(info[3]) if info is not None else 0
        title = f"Trace {trace_idx}, Swath {swath_idx}, Columns {left}-{right}, Iter {niter}"
        if self.title is not None:
            title = f"{self.title}\n{title}"
        self.fig.suptitle(title)
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()

        # Pace the animation to at least min_frame_time seconds per frame
        if self.last_frame_time is not None:
            elapsed = time.monotonic() - self.last_frame_time
            remaining = self.min_frame_time - elapsed
            if remaining > 0:
                plt.pause(remaining)
        self.last_frame_time = time.monotonic()

        self.wait_if_paused()

    @staticmethod
    def _curvature_offset(row_idx, ycen, slitcurve, slitdeltas):
        """Column shift of every pixel implied by the slit curvature.

        ``delta = sum_{i=1..5} slitcurve[:, i] * t**i (+ slitdeltas[row])`` with
        ``t = row - ycen`` the offset from the trace center within the swath.
        """
        t = row_idx - ycen
        delta = np.zeros_like(t, dtype=float)
        for i in range(1, min(6, slitcurve.shape[1])):
            delta += slitcurve[:, i] * (t**i)
        if slitdeltas is not None:
            delta += slitdeltas[:, np.newaxis]
        return delta

    @staticmethod
    def _fill_zero_spectrum(spec):
        """Return a float copy of ``spec`` with zero entries replaced.

        With at least two non-zero samples the zeros are linearly
        inter-/extrapolated; with exactly one they take that value; with none
        they become NaN. The input is never modified.
        """
        spec = np.array(spec, dtype=float)
        zero = spec == 0
        if not np.any(zero):
            return spec
        nonzero = ~zero
        idx = np.arange(len(spec))
        if np.count_nonzero(nonzero) >= 2:
            return interp1d(idx[nonzero], spec[nonzero], fill_value="extrapolate")(idx)
        values = spec[nonzero]
        spec[zero] = np.median(values) if values.size else np.nan
        return spec

    def get_spec(self, img, spec, slitf, ycen, slitcurve=None, slitdeltas=None):
        """Get the spectrum corrected by the slit function.

        Returns the flattened (column position, ``img / slitfunction``) of every
        pixel; column positions are shifted by the slit curvature if given.
        """
        nrow = img.shape[0]
        row_idx, col_idx = np.indices(img.shape)
        ycen_frac = ycen - ycen.astype(int)

        # Slit position for slit function interpolation
        slit_pos = row_idx - ycen_frac + 0.5
        slit_grid = np.linspace(-1, nrow, len(slitf))
        sf = np.interp(slit_pos, slit_grid, slitf)

        # Spectrum contribution: image divided by slit function
        with np.errstate(divide="ignore", invalid="ignore"):
            spec_val = img / sf

        col_pos = col_idx.astype(float)
        if slitcurve is not None:
            col_pos = col_pos - self._curvature_offset(
                row_idx, ycen, slitcurve, slitdeltas
            )
        return col_pos.ravel(), spec_val.ravel()

    def get_slitf(self, img, spec, slitf, ycen, slitcurve=None, slitdeltas=None):
        """Get the slit function.

        Returns the flattened (slit position, ``img / spectrum``) of every pixel,
        where the spectrum is evaluated at the curvature-shifted column.
        ``spec`` is not modified.
        """
        row_idx, col_idx = np.indices(img.shape)
        ycen_frac = ycen - ycen.astype(int)

        # Slit position for display
        slit_pos = row_idx - ycen_frac + 1

        # Effective column position for spectrum lookup
        col_eff = col_idx.astype(float)
        if slitcurve is not None:
            col_eff = col_eff - self._curvature_offset(
                row_idx, ycen, slitcurve, slitdeltas
            )

        # Replace zeros in the spectrum to avoid dividing by zero
        spec = self._fill_zero_spectrum(spec)

        # Spectrum value at the effective column position (linear interpolation)
        spec_at_col = np.interp(col_eff, np.arange(len(spec)), spec)

        # Slit function contribution: image divided by spectrum
        with np.errstate(divide="ignore", invalid="ignore"):
            slitf_val = img / spec_at_col
        return slit_pos.ravel(), slitf_val.ravel()
[docs]
class Swath:
    """Container for the per-swath results of the slit decomposition.

    Keeps six parallel lists of length ``nswath`` — spec, slitf, model, unc,
    mask and info — all starting as ``None``. Indexing reads or writes one
    swath's six results as a tuple in exactly that order.
    """

    # Attribute names in the order used by __getitem__ / __setitem__
    _FIELDS = ("spec", "slitf", "model", "unc", "mask", "info")

    def __init__(self, nswath):
        self.nswath = nswath
        for name in self._FIELDS:
            setattr(self, name, [None] * nswath)

    def __len__(self):
        return self.nswath

    def __getitem__(self, key):
        return tuple(getattr(self, name)[key] for name in self._FIELDS)

    def __setitem__(self, key, value):
        for position, name in enumerate(self._FIELDS):
            getattr(self, name)[key] = value[position]
[docs]
def fix_parameters(xwd, cr, traces, nrow, ncol, ntrace, ignore_column_range=False):
    """Fix extraction height and column range so that all used pixels lie inside the image.

    I.e. the column range is cut so that everything is within the image.

    Parameters
    ----------
    xwd : float or array[ntrace] or None
        Total extraction height, split evenly above/below the trace. Values
        below 2 are fractions of the distance to the nearest neighbouring trace.
        None means 1.0.
    cr : 2-tuple(int), array[ntrace, 2] or None
        Column range, either one value for all traces, or the whole array.
        None means the full image width.
    traces : array[ntrace, degree]
        polynomial coefficients that describe each trace
    nrow : int
        Number of rows in the image
    ncol : int
        Number of columns in the image
    ntrace : int
        Number of traces in the image
    ignore_column_range : bool, optional
        if true does not change the column range, however this may lead to
        problems with the extraction, by default False

    Returns
    -------
    xwd : array
        fixed extraction height in pixels, one per returned trace
    cr : array
        fixed column range, one row per returned trace
    traces : array
        the traces; traces without any column fully inside the image are
        dropped (``xwd`` and ``cr`` are trimmed consistently)

    Raises
    ------
    ValueError
        If ``xwd`` or ``cr`` are arrays whose length does not match ``ntrace``.
    """
    if xwd is None:
        xwd = 1.0
    # Work in float so fractional heights are not truncated before rounding up
    if np.isscalar(xwd):
        xwd = np.full(ntrace, xwd, dtype=float)
    else:
        xwd = np.asarray(xwd, dtype=float)
        if xwd.ndim != 1:
            raise ValueError("extraction_height must be a scalar or 1D array")
        if len(xwd) != ntrace:
            raise ValueError(
                f"extraction_height array length {len(xwd)} doesn't match ntrace {ntrace}"
            )

    if cr is None:
        cr = np.tile([0, ncol], (ntrace, 1))
    else:
        cr = np.asarray(cr)
        if cr.ndim == 1:
            cr = np.tile(cr, (ntrace, 1))
        elif len(cr) != ntrace:
            raise ValueError(
                f"column_range array length {len(cr)} doesn't match ntrace {ntrace}"
            )

    traces = np.asarray(traces)
    # Add the extrapolated extension traces (and matching parameters) at both ends
    xwd = np.array([xwd[0], *xwd, xwd[-1]])
    cr = np.array([cr[0], *cr, cr[-1]])
    traces = extend_traces(traces, nrow)

    xwd = fix_extraction_height(xwd, traces, cr, ncol)
    if not ignore_column_range:
        traces_before = traces
        cr, traces = fix_column_range(cr, traces, xwd, nrow, ncol)
        if len(traces) != len(traces_before):
            # fix_column_range drops unusable traces without reporting which;
            # keep xwd aligned with the surviving traces.
            xwd = xwd[_surviving_rows(traces_before, traces)]

    traces = traces[1:-1]
    xwd = xwd[1:-1]
    cr = cr[1:-1]
    return xwd, cr, traces


def _surviving_rows(before, after):
    """Indices into ``before`` of the rows kept in ``after`` (an order-preserving subset)."""
    kept = []
    j = 0
    for i, row in enumerate(before):
        if j < len(after) and np.array_equal(row, after[j]):
            kept.append(i)
            j += 1
    return np.asarray(kept, dtype=int)
[docs]
def extend_traces(traces, nrow):
    """Add one extrapolated trace below the first and one above the last.

    With two or more traces the new ones are linear extrapolations of the
    coefficients (``2*first - second`` and ``2*last - second_to_last``). With a
    single trace there is no spacing to extrapolate from, so constant
    polynomials at y = 0 and y = nrow are used instead.

    Parameters
    ----------
    traces : array[ntrace, degree]
        trace polynomial coefficients
    nrow : int
        number of rows in the image

    Returns
    -------
    traces : array[ntrace + 2, degree]
        extended traces
    """
    ntrace, ncoef = traces.shape
    if ntrace < 2:
        below = [0] * ncoef
        above = [0] * (ncoef - 1) + [nrow]
    else:
        below = 2 * traces[0] - traces[1]
        above = 2 * traces[-1] - traces[-2]
    return np.array([below, *traces, above])
[docs]
def fix_extraction_height(xwd, traces, cr, ncol):
    """Convert fractional extraction height to pixel range.

    Entries below 2 are fractions: they are multiplied by the minimum vertical
    distance to the neighbouring traces (evaluated over the columns both share).
    Entries of 2 or more are already in pixels. The first and last entries
    (extension traces) are copied from their neighbours when any conversion
    happens. The result is rounded up to whole pixels.

    Parameters
    ----------
    xwd : array[ntrace]
        extraction full height per trace (not modified)
    traces : array[ntrace, degree]
        trace polynomial coefficients
    cr : array[ntrace, 2]
        column range to use
    ncol : int
        number of columns in image

    Returns
    -------
    xwd : array[ntrace]
        updated extraction full height in pixels (int)

    Raises
    ------
    ValueError
        If a trace and one of its neighbours share no columns inside the image.
    """
    # Float copy: never modify the caller's array and never truncate
    # fraction * distance before rounding up.
    xwd = np.array(xwd, dtype=float)
    if np.all(xwd >= 2):
        return np.ceil(xwd).astype(int)

    cr = np.asarray(cr)
    x = np.arange(ncol)
    for i in range(1, len(xwd) - 1):
        if xwd[i] >= 2:
            continue
        # Find minimum distance to neighboring traces over their shared columns
        min_dist = np.inf
        for k in (i - 1, i + 1):
            left = int(max(cr[i, 0], cr[k, 0]))
            right = int(min(cr[i, 1], cr[k, 1]))
            cols = x[left:right] if right > left else x[:0]
            if cols.size == 0:
                raise ValueError(
                    f"Check your column ranges. Traces {i} and {k} are weird"
                )
            current = np.polyval(traces[i], cols)
            neighbor = np.polyval(traces[k], cols)
            min_dist = min(min_dist, np.min(np.abs(current - neighbor)))
        xwd[i] *= min_dist

    xwd[0] = xwd[1]
    xwd[-1] = xwd[-2]
    return np.ceil(xwd).astype(int)
[docs]
def validate_traces_for_extraction(
    traces: list[Trace],
    extraction_height: float | np.ndarray | None,
    nrow: int,
    ncol: int,
) -> None:
    """Validate traces and mark invalid ones.

    Checks if each trace's extraction aperture (``pos ± height/2``) fits within
    the image for at least one column of its column range. Invalid traces are
    marked with ``trace.invalid = "reason"``; traces already marked invalid are
    skipped.

    Parameters
    ----------
    traces : list[Trace]
        Trace objects to validate. Modified in-place.
    extraction_height : float, array-like or None
        Extraction height(s) in pixels. Any array-like with at least one
        dimension (ndarray, list, tuple) holds one height per trace; a scalar
        (including a 0-d array) applies to all traces. If None, each trace's own
        ``height`` is used, falling back to 0.5 when that is None.
    nrow : int
        Number of rows in image.
    ncol : int
        Number of columns in image.
    """
    ix = np.arange(ncol)
    per_trace = extraction_height is not None and np.ndim(extraction_height) > 0
    for i, trace in enumerate(traces):
        if trace.invalid:
            continue

        if per_trace:
            height = extraction_height[i]
        elif extraction_height is not None:
            height = float(extraction_height)
        else:
            height = trace.height if trace.height is not None else 0.5
        half = height / 2

        # Check if extraction aperture stays within image
        y_cen = np.polyval(trace.pos, ix)
        y_bot = y_cen - half
        y_top = y_cen + half

        # Find columns where aperture is fully within image and inside the
        # trace's own column range
        col_start, col_end = trace.column_range
        valid_cols = np.flatnonzero((y_bot >= 0) & (y_top < nrow))
        valid_cols = valid_cols[(valid_cols >= col_start) & (valid_cols < col_end)]

        if len(valid_cols) == 0:
            trace.invalid = f"extraction height {height:.1f}px exceeds image bounds"
            logger.warning("Trace %d: %s, marking invalid", i, trace.invalid)
[docs]
def fix_column_range(column_range, traces, extraction_height, nrow, ncol):
    """Fix the column range, so that no pixels outside the image will be accessed.

    For every non-extension trace (all rows except the first and last) the
    aperture ``trace ± extraction_height/2`` is evaluated at every column. Of
    the contiguous runs of columns where it lies fully inside the image, the run
    overlapping the current column range the most is kept, clipped to that range
    (trace polynomials are not well defined outside the traced range). Traces
    with no such column at all are removed.

    Parameters
    ----------
    column_range : array[ntrace, 2]
        current column range (not modified; an updated copy is returned)
    traces : array[ntrace, degree]
        trace polynomial coefficients, with extension traces as first and last row
    extraction_height : array[ntrace] or float
        extraction full height in pixels
    nrow : int
        number of rows in the image
    ncol : int
        number of columns in the image

    Returns
    -------
    column_range : array[ntrace, 2]
        updated column range (extension rows copy their neighbours)
    traces : array[ntrace, degree]
        trace polynomial coefficients (may have rows removed if no valid pixels)
    """
    column_range = np.array(column_range)
    traces = np.asarray(traces)
    half = np.broadcast_to(
        np.asarray(extraction_height, dtype=float) / 2, (len(traces),)
    )
    ix = np.arange(ncol)
    to_remove = []

    # Loop over non extension traces
    for i in range(1, len(traces) - 1):
        # Shift trace down/up by half the extraction height
        coeff_bot = np.array(traces[i], dtype=float)
        coeff_top = coeff_bot.copy()
        coeff_bot[-1] -= half[i]
        coeff_top[-1] += half[i]
        y_bot = np.polyval(coeff_bot, ix)  # low edge of aperture
        y_top = np.polyval(coeff_top, ix)  # high edge of aperture

        inside = np.flatnonzero((y_bot >= 0) & (y_top < nrow))
        if inside.size == 0:
            logger.warning(
                "No columns are completely within the specified height for trace %d, removing it.",
                i - 1,
            )
            to_remove.append(i)
            continue

        regions = _contiguous_regions(inside)
        overlap = [
            min(end, column_range[i, 1]) - max(start, column_range[i, 0])
            for start, end in regions
        ]
        # Note: if no run overlaps the range, the clip collapses it to zero width
        start, end = regions[int(np.argmax(overlap))]
        column_range[i] = np.clip(
            [start, end], column_range[i, 0], column_range[i, 1]
        )

    if to_remove:
        column_range = np.delete(column_range, to_remove, axis=0)
        traces = np.delete(traces, to_remove, axis=0)

    # Extension traces take the range of their (surviving) neighbours
    column_range[0] = column_range[1]
    column_range[-1] = column_range[-2]
    return column_range, traces


def _contiguous_regions(points):
    """Split sorted integer indices into runs of consecutive values, as [start, end) pairs."""
    breaks = np.flatnonzero(np.diff(points) != 1)
    starts = np.concatenate(([points[0]], points[breaks + 1]))
    ends = np.concatenate((points[breaks], [points[-1]])) + 1
    return list(zip(starts.tolist(), ends.tolist()))
[docs]
def make_bins(swath_width, xlow, xhigh, ycen):
    """Create bins for the swathes.

    Bins are roughly equally sized, have roughly length swath width (if given)
    and overlap roughly half-half with each other.

    Parameters
    ----------
    swath_width : {int, None}
        initial value for the swath_width, bins will have roughly that size, but exact value may change
        if swath_width is None, determine a good value, from the data
    xlow : int
        lower bound for x values
    xhigh : int
        upper bound for x values
    ycen : array[ncol]
        center of the order trace

    Returns
    -------
    nbin : int
        number of non-overlapping bins (at least 1)
    bins_start : array[2 * nbin - 1]
        left(beginning) side of the (overlapping) swaths
    bins_end : array[2 * nbin - 1]
        right(ending) side of the (overlapping) swaths
    """
    if swath_width is None:
        ycen = np.asarray(ycen)
        ncol = len(ycen)
        crossings = np.unique(ycen.astype(int))  # distinct rows touched by the trace
        if len(crossings) > 1:  # Curved order crosses rows
            # average step between the distinct integer rows
            step = np.sum(crossings[1:] - crossings[:-1]) / (len(crossings) - 1)
            # number of swaths along the order
            nbin = np.clip(int(np.round(ncol / step)) // 3, 3, 20)
        else:  # Perfectly aligned orders
            nbin = np.clip(ncol // 400, 3, None)  # Still follow the changes in PSF
        # Adjust for the true order length, but never drop to zero bins
        nbin = max(int(nbin * (xhigh - xlow) // ncol), 1)
    else:
        nbin = np.clip(int(np.round((xhigh - xlow) / swath_width)), 1, None)

    bins = np.linspace(xlow, xhigh, 2 * nbin + 1)  # boundaries of half-bins
    bins_start = np.ceil(bins[:-2]).astype(int)  # beginning of each swath
    bins_end = np.floor(bins[2:]).astype(int)  # end of each swath
    return nbin, bins_start, bins_end
[docs]
def calc_telluric_correction(telluric, img):  # pragma: no cover
    """Calculate telluric correction.

    If set to a specific integer larger than 5 (and smaller than half the swath
    height) it is used as the offset from the order center line; otherwise an
    offset of ``min(5, nrow / 3)`` rows is used. The sky is then estimated, per
    column, as the median signal between this offset and the upper/lower limit
    of the extraction window.

    Parameters
    ----------
    telluric : int
        telluric correction parameter (offset from the swath center, in rows)
    img : array[nrow, ncol]
        image of the swath (may be a masked array)

    Returns
    -------
    tell : array[ncol]
        telluric correction for each column; 0 where all outer pixels are masked
    """
    # Rows run along the slit, columns along the dispersion direction
    nrow = img.shape[0]
    tel_lim = telluric if 5 < telluric < nrow / 2 else min(5, nrow / 3)

    rows = np.arange(nrow)
    outer = rows[np.abs(rows - nrow / 2) >= tel_lim]
    sky = np.ma.median(img[outer, :], axis=0)
    return np.asarray(np.ma.filled(sky, 0.0), dtype=float)
[docs]
def calc_scatter_correction(scatter, index):
    """Evaluate the scattered-light background at the given swath pixels.

    Parameters
    ----------
    scatter : array of shape (degree_x, degree_y)
        2D polynomial coefficients of the background scatter
    index : tuple (array, array)
        (row, column) indices of the swath within the overall image

    Returns
    -------
    scatter_correction : array of shape (swath_width, swath_height)
        correction for scattered light, i.e. ``P(x=column, y=row)``
    """
    rows, cols = index
    # The polynomial's first variable is the column, the second the row
    return np.polynomial.polynomial.polyval2d(cols, rows, scatter)
[docs]
def extract_spectrum(
    img,
    ycen,
    yrange,
    xrange,
    gain=1,
    readnoise=0,
    lambda_sf=0.1,
    lambda_sp=0,
    osample=1,
    swath_width=None,
    maxiter=20,
    reject_threshold=6,
    telluric=None,
    scatter=None,
    normalize=False,
    threshold=0,
    curvature=None,
    slitdeltas=None,
    plot=False,
    plot_title=None,
    im_norm=None,
    im_ordr=None,
    out_spec=None,
    out_sunc=None,
    out_slitf=None,
    out_mask=None,
    progress=None,
    ord_num=0,
    preset_slitfunc=None,
    **kwargs,
):
    """
    Extract the spectrum of a single order from an image.

    The order is split into several swaths of roughly swath_width length, which
    overlap by half. A spectrum and slit function are extracted for each swath.
    Overlapping sections are combined using linear weights: the centre of a
    swath is strongest and the weight falls off towards its edges.

    Here is the layout for the bins:

    ::

           1st swath    3rd swath    5th swath      ...
        /============|============|============|============|============|

                  2nd swath    4th swath    6th swath
               |------------|------------|------------|------------|
               |.....|
               overlap

               +     ******* 1
                +   *
                 + *
                  *            weights (+) previous swath, (*) current swath
                 * +
                *   +
               *     +++++++ 0

    Parameters
    ----------
    img : array[nrow, ncol]
        observation (or similar)
    ycen : array[ncol]
        order trace of the current order
    yrange : tuple(int, int)
        extraction width in pixels, below and above
    xrange : tuple(int, int)
        columns range to extract (low, high). NOTE: modified in place
        (narrowed by the swath edge margins), so it must be a mutable
        sequence such as a list or array row.
    gain : float, optional
        adu to electron, amplifier gain (default: 1)
    readnoise : float, optional
        read out noise factor (default: 0)
    lambda_sf : float, optional
        slit function smoothing parameter, usually very small (default: 0.1)
    lambda_sp : int, optional
        spectrum smoothing parameter, usually very small (default: 0)
    osample : int, optional
        oversampling factor, i.e. how many subpixels to create per pixel (default: 1, i.e. no oversampling)
    swath_width : int, optional
        swath width suggestion, actual width depends also on ncol, see make_bins (default: None, which will determine the width based on the order tracing)
    maxiter : int, optional
        maximum number of iterations passed to the slit decomposition backend (default: 20)
    reject_threshold : float, optional
        outlier rejection threshold passed to the slit decomposition backend (default: 6)
    telluric : {float, None}, optional
        telluric correction factor (default: None, i.e. no telluric correction)
    scatter : {array, None}, optional
        background scatter as 2d polynomial coefficients (default: None, no correction)
    normalize : bool, optional
        whether to create a normalized image. If true, im_norm and im_ordr are used as output (default: False)
    threshold : int, optional
        threshold for normalization; model pixels at or below threshold / gain
        keep a normalized value of 1 (default: 0)
    curvature : array[ncol, n_coeffs], optional
        Slit curvature polynomial coefficients for this trace (default: None, i.e. vertical extraction).
        Truncated to 3 coefficients unless the charslit backend is enabled.
    slitdeltas : array[nrows_stored], optional
        Per-row residual offsets from polynomial curvature fit (default: None).
        Will be interpolated to match swath nrows if lengths differ.
    plot : bool, optional
        whether to plot the progress, plotting will slow down the procedure significantly (default: False)
    plot_title : str, optional
        title for the progress plot (default: None)
    im_norm : array[nrow, ncol], optional
        normalized image, only output if normalize is True (default: None).
        Values are accumulated into it, so it should start at zero.
    im_ordr : array[nrow, ncol], optional
        image of the order blaze, only output if normalize is True (default: None).
        Values are accumulated into it, so it should start at zero.
    out_spec, out_sunc : array[ncol], optional
        output buffers for the spectrum and its uncertainty. Weighted swath
        results are added into them, so they should start at zero.
    out_slitf : array[nslitf], optional
        output buffer for the mean slit function
    out_mask : array[ncol], optional
        output buffer for the column mask
    progress : ProgressPlot, optional
        existing progress plot to reuse; created on demand when plotting
    ord_num : int, optional
        current order number, just for plotting (default: 0)
    preset_slitfunc : array[nslitf], optional
        preset slit function handed to the backend. Must have
        osample * (ylow + yhigh + 2) + 1 elements.
    **kwargs
        ignored; allows callers to pass a shared settings dict

    Returns
    -------
    spec : array[ncol]
        extracted spectrum
    slitf : array[nslitf]
        extracted slitfunction (mean over all swaths)
    mask : array[ncol]
        mask of the column range; True marks columns outside the usable range
    unc : array[ncol]
        uncertainty on the spectrum

    Raises
    ------
    ValueError
        if preset_slitfunc does not have the expected length
    """
    _, ncol = img.shape
    ylow, yhigh = yrange
    xlow, xhigh = xrange
    # Length of the slit function for this extraction window and oversampling
    nslitf = osample * (ylow + yhigh + 2) + 1
    # Validate preset_slitfunc size before extraction
    if preset_slitfunc is not None and len(preset_slitfunc) != nslitf:
        raise ValueError(
            f"preset_slitfunc size mismatch: got {len(preset_slitfunc)} elements, "
            f"expected {nslitf} for osample={osample}, yrange=({ylow}, {yhigh}). "
            f"Ensure norm_flat and extraction use the same extraction_height and osample."
        )
    # CFFI backend only supports curvature degree <= 2; truncate if needed
    if not _use_charslit() and curvature is not None and curvature.shape[1] > 3:
        logger.warning(
            "curve_degree > 2 requires charslit backend. "
            "Truncating to degree 2. Set PYREDUCE_USE_CHARSLIT=1 for full curvature support."
        )
        curvature = curvature[:, :3]
    # Integer trace row per column; swaths are cut relative to it
    ycen_int = np.floor(ycen).astype(int)
    # Use caller-supplied buffers when given (results are written by reference)
    spec = np.zeros(ncol) if out_spec is None else out_spec
    sunc = np.zeros(ncol) if out_sunc is None else out_sunc
    mask = np.full(ncol, False) if out_mask is None else out_mask
    slitf = np.zeros(nslitf) if out_slitf is None else out_slitf
    # make_bins yields 2 * nbin - 1 half-overlapping swaths
    nbin, bins_start, bins_end = make_bins(swath_width, xlow, xhigh, ycen)
    nswath = 2 * nbin - 1
    swath = Swath(nswath)
    # Number of unreliable edge columns (left, right) per swath
    margin = np.zeros((nswath, 2), int)
    if normalize:
        norm_img = [None] * nswath
        norm_model = [None] * nswath
    # Perform slit decomposition within each swath stepping through the order with
    # half swath width. Spectra for each decomposition are combined with linear weights.
    with tqdm(
        enumerate(zip(bins_start, bins_end, strict=False)),
        total=len(bins_start),
        leave=False,
        desc="Swath",
    ) as t:
        for ihalf, (ibeg, iend) in t:
            logger.debug("Extracting Swath %i, Columns: %i - %i", ihalf, ibeg, iend)
            # Cut out swath from image
            # (assumes make_index returns index arrays, i.e. advanced indexing,
            # so swath_img is a copy and the corrections below leave img intact)
            index = make_index(ycen_int - ylow, ycen_int + yhigh, ibeg, iend)
            swath_img = img[index]
            # Convert ycen to swath-relative coordinates
            # The swath is cut from ycen_int - ylow, so within the swath:
            # trace center = ylow + fractional_part(ycen)
            swath_ycen_abs = ycen[ibeg:iend]
            swath_ycen = ylow + (swath_ycen_abs - np.floor(swath_ycen_abs))
            # Corrections
            # TODO: what is it even supposed to do?
            if telluric is not None:  # pragma: no cover
                telluric_correction = calc_telluric_correction(telluric, swath_img)
            else:
                telluric_correction = 0
            if scatter is not None:
                scatter_correction = calc_scatter_correction(scatter, index)
            else:
                scatter_correction = 0
            swath_img -= scatter_correction + telluric_correction
            # Do Slitfunction extraction
            swath_ncols = iend - ibeg
            swath_nrows = swath_img.shape[0]
            swath_curv = curvature[ibeg:iend] if curvature is not None else None
            # Copy of the input mask, kept only for the progress plot
            input_mask = np.ma.getmaskarray(swath_img).copy()
            # Prepare curvature for both backends and visualization
            slitcurve = _ensure_slitcurve(swath_curv, swath_ncols)
            if _use_deltas() and slitdeltas is not None and len(slitdeltas) > 0:
                # Interpolate slitdeltas to match swath nrows if needed
                if len(slitdeltas) == swath_nrows:
                    swath_slitdeltas = slitdeltas.astype(np.float64)
                else:
                    x_stored = np.linspace(0, 1, len(slitdeltas))
                    x_swath = np.linspace(0, 1, swath_nrows)
                    swath_slitdeltas = np.interp(x_swath, x_stored, slitdeltas)
                    swath_slitdeltas = swath_slitdeltas.astype(np.float64)
            else:
                swath_slitdeltas = None
            if _use_charslit():
                # charslit always expects a slitdeltas array; zeros = no offsets
                charslit_slitdeltas = (
                    swath_slitdeltas
                    if swath_slitdeltas is not None
                    else np.zeros(swath_nrows, dtype=np.float64)
                )
                swath[ihalf] = _slitdec_charslit(
                    swath_img,
                    swath_ycen_abs,
                    slitcurve,
                    charslit_slitdeltas,
                    lambda_sp=lambda_sp,
                    lambda_sf=lambda_sf,
                    osample=osample,
                    yrange=yrange,
                    maxiter=maxiter,
                    gain=gain,
                    reject_threshold=reject_threshold,
                    preset_slitfunc=preset_slitfunc,
                )
            else:
                swath[ihalf] = _slitdec_cffi(
                    swath_img,
                    swath_ycen_abs,
                    swath_curv,
                    lambda_sp=lambda_sp,
                    lambda_sf=lambda_sf,
                    osample=osample,
                    yrange=yrange,
                    maxiter=maxiter,
                    gain=gain,
                    reject_threshold=reject_threshold,
                    preset_slitfunc=preset_slitfunc,
                )
            # NOTE(review): element 5 of the swath result appears to be the info
            # record; its entry 1 is shown as "chi" - confirm against Swath.
            t.set_postfix(chi=f"{swath[ihalf][5][1]:1.2f}")
            if normalize:
                # Save image and model for later
                # Use np.divide to avoid divisions by zero
                # (pixels where the model is not above threshold / gain stay 1)
                where = swath.model[ihalf] > threshold / gain
                norm_img[ihalf] = np.ones_like(swath.model[ihalf])
                np.divide(
                    np.abs(swath_img),
                    swath.model[ihalf],
                    where=where,
                    out=norm_img[ihalf],
                )
                norm_model[ihalf] = swath.model[ihalf]
            if (
                plot >= 2
                and not np.all(np.isnan(swath_img))
                and util.is_interactive_plot_mode()
            ):  # pragma: no cover
                if progress is None:
                    progress = ProgressPlot(
                        swath_img.shape[0], swath_img.shape[1], nslitf, title=plot_title
                    )
                progress.plot(
                    swath_img,
                    swath.spec[ihalf],
                    swath.slitf[ihalf],
                    swath.model[ihalf],
                    swath_ycen,
                    input_mask,
                    swath.mask[ihalf],
                    ord_num,
                    ibeg,
                    iend,
                    swath.unc[ihalf],
                    swath.info[ihalf],
                    ihalf,
                    slitcurve=slitcurve,
                    slitdeltas=swath_slitdeltas,
                )
    # Remove points at the border of the each swath, if order has curvature
    # as those pixels have bad information
    # NOTE(review): info[4] is presumably the curvature-induced horizontal
    # extent reported by the backend - confirm.
    for i in range(nswath):
        margin[i, :] = int(swath.info[i][4]) + 1
    # Weight for combining swaths
    weight = [np.ones(bins_end[i] - bins_start[i]) for i in range(nswath)]
    # The outermost edges of the first and last swath are never used
    weight[0][: margin[0, 0]] = 0
    weight[-1][len(weight[-1]) - margin[-1, 1] :] = 0
    for i, j in zip(range(0, nswath - 1), range(1, nswath), strict=False):
        width = bins_end[i] - bins_start[i]
        overlap = bins_end[i] - bins_start[j]
        # Start and end indices for the two swaths
        start_i = width - overlap + margin[j, 0]
        end_i = width - margin[i, 1]
        start_j = margin[j, 0]
        end_j = overlap - margin[i, 1]
        # Linear ramp across the overlap; the endpoints 0 and 1 are excluded so
        # that both swaths contribute to every overlapping column
        triangle = np.linspace(0, 1, overlap + 1, endpoint=False)[1:]
        # Cut away the margins at the corners
        triangle = triangle[margin[j, 0] : len(triangle) - margin[i, 1]]
        # Set values (the two weights sum to 1 in the overlap)
        weight[i][start_i:end_i] = 1 - triangle
        weight[j][start_j:end_j] = triangle
        # Don't use the pixels at the edges (due to curvature)
        weight[i][end_i:] = 0
        weight[j][:start_j] = 0
    # Update column range (in place: the caller's xrange is narrowed)
    xrange[0] += margin[0, 0]
    xrange[1] -= margin[-1, 1]
    mask[: xrange[0]] = True
    mask[xrange[1] :] = True
    # Apply weights
    for i, (ibeg, iend) in enumerate(zip(bins_start, bins_end, strict=False)):
        spec[ibeg:iend] += swath.spec[i] * weight[i]
        sunc[ibeg:iend] += swath.unc[i] * weight[i]
    if normalize:
        for i, (ibeg, iend) in enumerate(zip(bins_start, bins_end, strict=False)):
            index = make_index(ycen_int - ylow, ycen_int + yhigh, ibeg, iend)
            im_norm[index] += norm_img[i] * weight[i]
            im_ordr[index] += norm_model[i] * weight[i]
    slitf[:] = np.mean(swath.slitf, axis=0)
    # Add the read noise (converted to electrons' ADU scale via gain) in quadrature
    sunc[:] = np.sqrt(sunc**2 + (readnoise / gain) ** 2)
    return spec, slitf, mask, sunc
[docs]
def get_y_scale(ycen, xrange, extraction_height, nrow):
    """Determine how many rows below and above the trace to extract.

    The trace is shifted so that ``floor`` centres the window on it, the
    per-column window is pushed back inside the detector if it sticks out
    at the bottom or top, and the smallest distance between the trace and the
    window bottom over the column range defines the lower half-width.

    Parameters
    ----------
    ycen : array[ncol]
        order trace (not modified)
    xrange : tuple(int, int)
        column range
    extraction_height : int
        extraction full height in pixels
    nrow : int
        number of rows in the image, defines upper edge

    Returns
    -------
    ylow, yhigh : int, int
        lower and upper y bound for extraction (pixels below/above trace)
        These satisfy: ylow + yhigh + 1 = extraction_height
    """
    center = ycen[xrange[0] : xrange[1]].copy()
    center += 1 if extraction_height % 2 == 0 else 0.5

    bottom = np.floor(center - extraction_height // 2).astype(int)

    # Keep the window on the detector: push it up from row 0 ...
    lowest = np.min(bottom)
    if lowest < 0:
        bottom = bottom - lowest
    # ... and down from the last row
    highest = np.max(bottom)
    if highest >= nrow:
        bottom = bottom - highest + nrow - 1

    top = bottom + extraction_height - 1
    top_max = np.max(top)
    if top_max >= nrow:
        top = top - top_max + nrow - 1
        bottom = top - extraction_height + 1

    below = int(np.min(center - bottom))
    # The upper half-width is whatever remains of the full height
    above = extraction_height - 1 - below
    return below, above
[docs]
def optimal_extraction(
    img,
    traces,
    extraction_height,
    column_range,
    curvature=None,
    slitdeltas=None,
    plot=False,
    plot_title=None,
    **kwargs,
):
    """Use optimal extraction to get spectra

    This function just loops over the traces; the actual work is done in
    extract_spectrum, which writes its results into the per-trace rows of the
    output arrays.

    Parameters
    ----------
    img : array[nrow, ncol]
        image to extract
    traces : array[ntrace, degree]
        trace polynomial coefficients
    extraction_height : array[ntrace]
        extraction full height in pixels
    column_range : array[ntrace, 2]
        column range to use. NOTE: narrowed in place by extract_spectrum
        depending on the curvature margins.
    curvature : array[ntrace, ncol, n_coeffs] or None
        Slit curvature polynomial coefficients (default: None for vertical extraction)
    slitdeltas : array[ntrace, nrows] or None
        Per-row residual offsets from curvature fit (default: None)
    plot : bool or int, optional
        plot the comparison at the end; values >= 2 also show per-swath progress
    plot_title : str, optional
        title used for the plots
    **kwargs
        other parameters for the extraction (see extract_spectrum).
        ``preset_slitfunc`` may be a per-trace sequence of slit functions.

    Returns
    -------
    spectrum : masked array[ntrace, ncol]
        extracted spectrum, masked outside each trace's column range
    slitfunction : list[array[nslitf]]
        recovered slitfunction per trace
    uncertainties : masked array[ntrace, ncol]
        uncertainties on the spectrum
    """
    logger.info("Using optimal extraction to produce spectrum")
    nrow, ncol = img.shape
    ntrace = len(traces)
    spectrum = np.zeros((ntrace, ncol))
    uncertainties = np.zeros((ntrace, ncol))
    slitfunction = [None for _ in range(ntrace)]
    # Handle preset_slitfunc (list of per-trace slitfuncs); each trace gets its own
    preset_slitfunc = kwargs.pop("preset_slitfunc", None)
    # Loop invariant: needed to size each trace's slit function buffer
    osample = kwargs.get("osample", 1)
    # Add mask as defined by column ranges
    mask = np.full((ntrace, ncol), True)
    for i in range(ntrace):
        mask[i, column_range[i, 0] : column_range[i, 1]] = False
    spectrum = np.ma.array(spectrum, mask=mask)
    uncertainties = np.ma.array(uncertainties, mask=mask)
    ix = np.arange(ncol)
    if plot >= 2 and util.is_interactive_plot_mode():  # pragma: no cover
        # swath_width=None is a valid setting (make_bins then picks the width),
        # in which case kwargs.get's default would not apply; fall back explicitly.
        ncol_swath = kwargs.get("swath_width")
        if ncol_swath is None:
            ncol_swath = img.shape[1] // 400
        nrow_swath = np.max(extraction_height)
        nslitf_swath = (nrow_swath + 2) * osample + 1
        progress = ProgressPlot(nrow_swath, ncol_swath, nslitf_swath, title=plot_title)
    else:
        progress = None
    for i in tqdm(range(ntrace), desc="Trace"):
        logger.debug("Extracting trace %i out of %i", i + 1, ntrace)
        # Define a fixed height area containing one trace
        ycen = np.polyval(traces[i], ix)
        yrange = get_y_scale(ycen, column_range[i], extraction_height[i], nrow)
        # Shift ycen so floor() rounds to nearest integer, centering the trace
        # in the extraction window (sub-pixel offset near 0 instead of biased to +1)
        cr = column_range[i]
        ycen[cr[0] : cr[1]] += 0.5 if extraction_height[i] % 2 == 1 else 1
        slitfunction[i] = np.zeros(osample * (sum(yrange) + 2) + 1)
        # Return values are set by reference, as the out parameters
        # Also column_range is adjusted depending on the curvature
        # This is to avoid large chunks of memory of essentially duplicates
        order_slitfunc = None
        if preset_slitfunc is not None and i < len(preset_slitfunc):
            order_slitfunc = preset_slitfunc[i]
        trace_curv = curvature[i] if curvature is not None else None
        trace_slitdeltas = slitdeltas[i] if slitdeltas is not None else None
        extract_spectrum(
            img,
            ycen,
            yrange,
            column_range[i],
            curvature=trace_curv,
            slitdeltas=trace_slitdeltas,
            out_spec=spectrum[i],
            out_sunc=uncertainties[i],
            out_slitf=slitfunction[i],
            out_mask=mask[i],
            progress=progress,
            ord_num=i + 1,
            plot=plot,
            plot_title=plot_title,
            preset_slitfunc=order_slitfunc,
            **kwargs,
        )
    if plot >= 2 and progress is not None:  # pragma: no cover
        progress.close()
    if plot:  # pragma: no cover
        plot_comparison(
            img,
            traces,
            spectrum,
            slitfunction,
            extraction_height,
            column_range,
            title=plot_title,
        )
    return spectrum, slitfunction, uncertainties
[docs]
def correct_for_curvature(img_order, curvature, xwd, inverse=False):
"""Correct image for slit curvature by interpolation.
Parameters
----------
img_order : array[nrow, ncol]
Image swath to correct
curvature : array[ncol, n_coeffs]
Curvature coefficients [c0, c1, c2, ...] where dx = c1*y + c2*y^2 + ...
xwd : int
Extraction full height in pixels
inverse : bool
If True, apply inverse correction (for model reapplication)
Returns
-------
img_order : array
Corrected image
"""
mask = ~np.ma.getmaskarray(img_order)
sign = -1 if inverse else 1
half = xwd // 2
xt = np.arange(img_order.shape[1])
for y, yt in zip(range(xwd), range(-half, xwd - half), strict=False):
# Compute displacement: dx = c1*y + c2*y^2 + c3*y^3 + ...
dx = np.zeros(img_order.shape[1])
for k in range(1, curvature.shape[1]):
dx += curvature[:, k] * (yt**k)
xi = xt + sign * dx
img_order[y] = np.interp(
xi, xt[mask[y]], img_order[y][mask[y]], left=0, right=0
)
xt = np.arange(img_order.shape[0])
for x in range(img_order.shape[1]):
img_order[:, x] = np.interp(
xt, xt[mask[:, x]], img_order[:, x][mask[:, x]], left=0, right=0
)
return img_order
[docs]
def simple_extraction(
    img,
    traces,
    extraction_height,
    column_range,
    gain=1,
    readnoise=0,
    dark=0,
    plot=False,
    plot_title=None,
    curvature=None,
    collapse_function="median",
    **kwargs,
):
    """Use simple extraction to get a spectrum

    Simple extraction takes the sum/mean/median orthogonal to the trace for
    extraction_height pixels. This extraction makes a few rough assumptions and
    does not provide the most accurate results, but rather a good approximation.

    Parameters
    ----------
    img : array[nrow, ncol]
        image to extract
    traces : array[ntrace, degree]
        trace polynomial coefficients
    extraction_height : array[ntrace]
        extraction full height in pixels
    column_range : array[ntrace, 2]
        column range to use
    gain : float, optional
        adu to electron, amplifier gain (default: 1)
    readnoise : float, optional
        read out noise (default: 0)
    dark : float, optional
        dark current noise (default: 0)
    plot : bool, optional
        whether to plot the results (default: False)
    plot_title : str, optional
        title for the comparison plot
    curvature : array[ntrace, ncol, n_coeffs], optional
        slit curvature coefficients; if given, each trace window is
        rectified with correct_for_curvature before collapsing
    collapse_function : {"sum", "mean", "median"}, optional
        how to collapse the window along the slit. "mean" and "median" are
        scaled by the window height so all options are on a sum-like scale
        (default: "median")
    **kwargs
        ignored; allows callers to pass a shared settings dict

    Returns
    -------
    spectrum : masked array[ntrace, ncol]
        extracted spectrum
    uncertainties : masked array[ntrace, ncol]
        uncertainties on extracted spectrum

    Raises
    ------
    ValueError
        if collapse_function is not one of "sum", "mean", "median"
    """
    # Fail fast, before any trace has been rectified/interpolated
    if collapse_function not in ("sum", "mean", "median"):
        raise ValueError(
            f"Could not determine the arc method, expected one of ('sum', 'mean', 'median'), but got {collapse_function}"
        )
    logger.info("Using simple extraction to produce spectrum")
    _, ncol = img.shape
    ntrace, _ = traces.shape
    spectrum = np.zeros((ntrace, ncol))
    uncertainties = np.zeros((ntrace, ncol))
    # Add mask as defined by column ranges
    mask = np.full((ntrace, ncol), True)
    for i in range(ntrace):
        mask[i, column_range[i, 0] : column_range[i, 1]] = False
    spectrum = np.ma.array(spectrum, mask=mask)
    uncertainties = np.ma.array(uncertainties, mask=mask)
    x = np.arange(ncol)
    for i in tqdm(range(ntrace), desc="Trace"):
        logger.debug("Extracting trace %i out of %i", i + 1, ntrace)
        x_left_lim = column_range[i, 0]
        x_right_lim = column_range[i, 1]
        # Rectify the image, i.e. remove the shape of the trace
        # Then the center of the trace is within one pixel variations
        ycen = np.polyval(traces[i], x).astype(int)
        half = extraction_height[i] // 2
        yb = ycen - half
        yt = yb + extraction_height[i] - 1
        index = make_index(yb, yt, x_left_lim, x_right_lim)
        img_trace = img[index]
        # Correct for curvature
        # For each row of the rectified trace, interpolate onto the shifted row
        # Masked pixels are set to 0, similar to the summation
        if curvature is not None:
            trace_curv = curvature[i, x_left_lim:x_right_lim]
            img_trace = correct_for_curvature(
                img_trace,
                trace_curv,
                extraction_height[i],
            )
        # Collapse along the slit (collapse_function was validated above)
        if collapse_function == "sum":
            arc = np.ma.sum(img_trace, axis=0)
        elif collapse_function == "mean":
            arc = np.ma.mean(img_trace, axis=0) * img_trace.shape[0]
        else:  # "median"
            arc = np.ma.median(img_trace, axis=0) * img_trace.shape[0]
        # Store results
        spectrum[i, x_left_lim:x_right_lim] = arc
        uncertainties[i, x_left_lim:x_right_lim] = (
            np.sqrt(np.abs(arc * gain + dark + readnoise**2)) / gain
        )
    if plot:  # pragma: no cover
        plot_comparison(
            img,
            traces,
            spectrum,
            None,
            extraction_height,
            column_range,
            title=plot_title,
        )
    return spectrum, uncertainties
[docs]
def plot_comparison(
    original, traces, spectrum, slitf, extraction_height, column_range, title=None
):  # pragma: no cover
    """Plot the rectified trace windows with the extracted spectra overlaid.

    Each trace's extraction window is cut out of ``original``, clipped to its
    5th-95th percentile, rescaled to [0, 1] and stacked vertically. The
    log-scaled extracted spectrum of each trace is drawn in red on top of its
    window.

    Parameters
    ----------
    original : array[nrow, ncol]
        Image the spectra were extracted from.
    traces : array[ntrace, degree]
        Trace polynomial coefficients (evaluated with np.polyval).
    spectrum : array[ntrace, ncol]
        Extracted spectra.
    slitf : list or None
        Slit functions; currently unused.
    extraction_height : array[ntrace]
        Extraction full height in pixels.
    column_range : array[ntrace, 2]
        Column range of each trace.
    title : str, optional
        Prefix for the plot title.
    """
    plt.figure()
    nrow, ncol = original.shape
    ntrace = len(traces)
    output = np.zeros((np.sum(extraction_height) + ntrace, ncol))
    pos = [0]
    x = np.arange(ncol)
    for i in range(ntrace):
        ycen = np.polyval(traces[i], x)
        half = extraction_height[i] // 2
        yb = ycen - half
        yt = yb + extraction_height[i] - 1
        xl, xr = column_range[i]
        index = make_index(yb, yt, xl, xr)
        yl = pos[i]
        yr = pos[i] + index[0].shape[0]
        output[yl:yr, xl:xr] = original[index]
        vmin, vmax = np.percentile(output[yl:yr, xl:xr], (5, 95))
        output[yl:yr, xl:xr] = np.clip(output[yl:yr, xl:xr], vmin, vmax)
        output[yl:yr, xl:xr] -= vmin
        # A flat window (vmax == vmin) would otherwise turn into 0/0 = NaN
        if vmax > vmin:
            output[yl:yr, xl:xr] /= vmax - vmin
        pos += [yr]
    plt.imshow(output, origin="lower", aspect="auto")
    for i in range(ntrace):
        try:
            tmp = spectrum[i, column_range[i, 0] : column_range[i, 1]]
            vmin = np.min(tmp[tmp != 0])
            tmp = np.copy(spectrum[i])
            tmp[tmp != 0] -= vmin
            np.log(tmp, out=tmp, where=tmp > 0)
            tmp = tmp / np.max(tmp) * 0.9 * (pos[i + 1] - pos[i])
            tmp += pos[i]
            tmp[tmp < pos[i]] = pos[i]
            plt.plot(x, tmp, "r")
        except Exception as err:
            # The overlay is best-effort (e.g. an all-zero spectrum makes
            # np.min fail on an empty selection); skip this trace but keep a
            # record instead of silently swallowing everything.
            logger.debug("Could not overlay spectrum of trace %i: %s", i, err)
    locs = np.asarray(extraction_height) + 1
    locs = np.array([0, *np.cumsum(locs)[:-1]])
    locs[:-1] += (np.diff(locs) * 0.5).astype(int)
    locs[-1] += ((output.shape[0] - locs[-1]) * 0.5).astype(int)
    plt.yticks(locs, range(len(locs)))
    plot_title = "Extracted Spectrum vs. Rectified Image"
    if title is not None:
        plot_title = f"{title}\n{plot_title}"
    plt.title(plot_title)
    plt.xlabel("x [pixel]")
    plt.ylabel("trace")
    util.show_or_save("extract_rectify")
[docs]
def extract(
    img,
    traces: list[Trace],
    extraction_height: float = 0.5,
    extraction_type: str = "optimal",
    **kwargs,
) -> list[Spectrum]:
    """
    Extract spectra from an image.

    Parameters
    ----------
    img : array[nrow, ncol](float)
        Observation to extract.
    traces : list[Trace]
        Trace objects with position, column_range, and optional slit curvature.
        Traces marked invalid by validate_traces_for_extraction are skipped.
    extraction_height : float, optional
        Extraction height. Values below 2 are fractions of trace spacing,
        values above are pixels. If None, falls back to trace.height. (default: 0.5)
    extraction_type : {"optimal", "simple", "arc"}, optional
        Extraction algorithm; "arc" is an alias for "simple". (default: "optimal")
    **kwargs
        Additional parameters for extraction functions (osample, lambda_sf, etc.)

    Returns
    -------
    list[Spectrum]
        Extracted spectrum objects, one per valid trace.

    Raises
    ------
    ValueError
        If extraction_type is not one of the supported algorithms.
    """
    if len(traces) == 0:
        return []
    # Reject an unknown algorithm before any expensive preparation
    # (trace validation, per-column curvature evaluation, fix_parameters)
    if extraction_type not in ("optimal", "simple", "arc"):
        raise ValueError(
            f"extraction_type must be 'optimal' or 'simple', got {extraction_type}"
        )
    nrow, ncol = img.shape
    # Validate traces and mark invalid ones
    validate_traces_for_extraction(traces, extraction_height, nrow, ncol)
    # Filter to valid traces only
    valid_traces = [t for t in traces if not t.invalid]
    if len(valid_traces) == 0:
        logger.warning("No valid traces remaining after validation")
        return []
    ntrace = len(valid_traces)
    # Convert Trace objects to arrays for internal processing
    traces_arr = np.array([t.pos for t in valid_traces])
    column_range = np.array(
        [list(t.column_range) for t in valid_traces], dtype=np.int32
    )
    # Settings value takes precedence; fall back to trace.height when None
    if extraction_height is not None:
        heights = np.full(ntrace, extraction_height)
    else:
        heights = np.array(
            [t.height if t.height is not None else 0.5 for t in valid_traces]
        )
    # Build curvature arrays from Trace.slit
    curvature = None
    if any(t.slit is not None for t in valid_traces):
        # Get max slit dimensions
        max_y = max(
            (t.slit.shape[0] if t.slit is not None else 0) for t in valid_traces
        )
        max_x = max(
            (t.slit.shape[1] if t.slit is not None else 0) for t in valid_traces
        )
        if max_y > 0 and max_x > 0:
            curvature = np.zeros((ntrace, ncol, max_y), dtype=np.float64)
            for i, t in enumerate(valid_traces):
                if t.slit is not None:
                    # Evaluate slit polynomial at each column
                    for j in range(ncol):
                        coeffs = t.slit_at_x(j)
                        if coeffs is not None:
                            curvature[i, j, : len(coeffs)] = coeffs
    # Build slitdeltas array from Trace.slitdelta (zero-padded to a common length)
    slitdeltas = None
    if any(t.slitdelta is not None for t in valid_traces):
        max_len = max(
            (len(t.slitdelta) if t.slitdelta is not None else 0) for t in valid_traces
        )
        if max_len > 0:
            slitdeltas = np.zeros((ntrace, max_len), dtype=np.float32)
            for i, t in enumerate(valid_traces):
                if t.slitdelta is not None:
                    slitdeltas[i, : len(t.slitdelta)] = t.slitdelta
    # Fix the input parameters
    heights, column_range, traces_arr = fix_parameters(
        heights, column_range, traces_arr, nrow, ncol, ntrace
    )
    # Perform extraction
    if extraction_type == "optimal":
        spectrum, slitfunction, uncertainties = optimal_extraction(
            img,
            traces_arr,
            heights,
            column_range,
            curvature=curvature,
            slitdeltas=slitdeltas,
            **kwargs,
        )
    else:  # "simple" or "arc" (validated above)
        spectrum, uncertainties = simple_extraction(
            img,
            traces_arr,
            heights,
            column_range,
            curvature=curvature,
            **kwargs,
        )
        slitfunction = [None] * ntrace
    # Convert results to Spectrum objects
    results = []
    for i, trace in enumerate(valid_traces):
        # Convert masked array to NaN-masked regular array
        spec_1d = np.ma.filled(spectrum[i], np.nan)
        unc_1d = np.ma.filled(uncertainties[i], np.nan)
        slitfu = slitfunction[i] if slitfunction else None
        # Evaluate wavelength from trace polynomial if available
        wave = trace.wlen(np.arange(len(spec_1d)))
        results.append(
            Spectrum.from_trace(
                trace,
                spec_1d,
                unc_1d,
                wave=wave,
                slitfu=slitfu,
                extraction_height=heights[i],
            )
        )
    return results
[docs]
def extract_normalize(
    img,
    traces: list[Trace],
    extraction_height: float = 0.5,
    **kwargs,
):
    """
    Extract a flat field and build its normalized image.

    Specialised extraction mode for flat field processing: the traces are
    extracted with optimal extraction in normalize mode, and the normalized
    and model images are returned instead of Spectrum objects.

    Parameters
    ----------
    img : array[nrow, ncol](float)
        Flat field image to normalize.
    traces : list[Trace]
        Trace objects with position, column_range, height, slit curvature.
    extraction_height : float, optional
        Extraction height. If None, falls back to trace.height (0.5 when unset).
    **kwargs
        Additional parameters for extraction. ``threshold_lower`` (default 0)
        is also used here: output pixels at or below it are set to 1.

    Returns
    -------
    im_norm : array[nrow, ncol](float)
        Normalized flat field image.
    im_ordr : array[nrow, ncol](float)
        Image with just the trace regions.
    blaze : array[ntrace, ncol](float)
        Extracted blaze function.
    slitfunction : list
        Recovered slit functions.
    column_range : array[ntrace, 2](int)
        Column ranges used.

    Raises
    ------
    ValueError
        If no traces are given.
    """
    if not traces:
        raise ValueError("No traces provided")
    nrow, ncol = img.shape
    ntrace = len(traces)

    # Array form of the trace geometry
    positions = np.array([trace.pos for trace in traces])
    column_range = np.array(
        [list(trace.column_range) for trace in traces], dtype=np.int32
    )

    # Explicit setting wins; otherwise use each trace's own height
    if extraction_height is None:
        heights = np.array(
            [0.5 if trace.height is None else trace.height for trace in traces]
        )
    else:
        heights = np.full(ntrace, extraction_height)

    # Curvature: per-column slit polynomial coefficients, zero-padded
    curvature = None
    curved = [(idx, trace) for idx, trace in enumerate(traces) if trace.slit is not None]
    if curved:
        n_coeff = max(trace.slit.shape[0] for _, trace in curved)
        n_xcoeff = max(trace.slit.shape[1] for _, trace in curved)
        if n_coeff > 0 and n_xcoeff > 0:
            curvature = np.zeros((ntrace, ncol, n_coeff), dtype=np.float64)
            for idx, trace in curved:
                for col in range(ncol):
                    coeffs = trace.slit_at_x(col)
                    if coeffs is not None:
                        curvature[idx, col, : len(coeffs)] = coeffs

    # Slit deltas: per-row residual offsets, zero-padded to a common length
    slitdeltas = None
    with_deltas = [
        (idx, trace) for idx, trace in enumerate(traces) if trace.slitdelta is not None
    ]
    if with_deltas:
        longest = max(len(trace.slitdelta) for _, trace in with_deltas)
        if longest > 0:
            slitdeltas = np.zeros((ntrace, longest), dtype=np.float32)
            for idx, trace in with_deltas:
                slitdeltas[idx, : len(trace.slitdelta)] = trace.slitdelta

    heights, column_range, positions = fix_parameters(
        heights, column_range, positions, nrow, ncol, ntrace
    )

    # Output images are filled by reference during the extraction
    im_norm = np.zeros_like(img)
    im_ordr = np.zeros_like(img)
    blaze, slitfunction, _ = optimal_extraction(
        img,
        positions,
        heights,
        column_range,
        curvature=curvature,
        slitdeltas=slitdeltas,
        normalize=True,
        im_norm=im_norm,
        im_ordr=im_ordr,
        **kwargs,
    )

    threshold_lower = kwargs.get("threshold_lower", 0)
    for image in (im_norm, im_ordr):
        image[image <= threshold_lower] = 1
    return im_norm, im_ordr, blaze, slitfunction, column_range