"""
Trace data model for PyReduce.
This module defines the Trace dataclass and I/O functions for storing
trace positions, curvature, and wavelength calibration in FITS format.
The Trace dataclass consolidates what was previously scattered across
separate files (traces.npz, curve.npz, wavecal.npz) into a single structure.
"""
from __future__ import annotations
import logging
import re
from dataclasses import dataclass
from pathlib import Path
import astropy.io.fits as fits
import numpy as np
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Format version of the trace FITS files. save_traces() writes it to the
# E_FMTVER header keyword and load_traces() reads it back for backwards
# compatibility detection (files without the keyword are treated as version 1).
# v2: Initial FITS format with FIBER column
# v3: Renamed FIBER→GROUP, added FIBER_IDX column
# v4: Added BUNDLE column (bundle id, independent of m)
FORMAT_VERSION = 4
[docs]
@dataclass
class Trace:
    """Geometry and calibration data of one trace on the detector.

    A trace is a single spectral order, or a single fiber within an order.

    Attributes
    ----------
    m : int | None
        Spectral order number (diffraction order). This is the physical
        order number from the grating equation, not a sequential index. In
        echelle spectrographs, higher order numbers correspond to shorter
        wavelengths.

        The order number is assigned in one of three ways:

        1. **From order_centers.yaml** (preferred): if the instrument
           provides an ``order_centers_{channel}.yaml`` file with known order
           positions, traces are matched to these centers during detection
           and receive the corresponding order numbers immediately.
        2. **From wavelength calibration**: if no order_centers file exists,
           ``m`` is initially None. During wavelength calibration the
           linelist file provides ``obase`` (the base order number) and each
           trace is assigned ``m = obase + trace_index``.
        3. **Sequential fallback**: for legacy files or MOSAIC mode where the
           order identity cannot be determined, ``m`` may remain None or be
           assigned sequentially from 0.

        Note that ``Trace.wlen()`` does not use ``m`` when evaluating a 2D
        wavelength polynomial; it uses the trace index ``_wave_idx`` as the
        second coordinate.
    bundle : int | None
        Bundle identifier (1-indexed), independent of `m`. Used by
        instruments where fibers are organised into spatial bundles within
        each spectral order (MOSAIC: 90 bundles × 7 fibers; in principle also
        bundled echelle: m orders × bundles). For instruments without a
        bundle concept (ANDES groups, single-fiber echelles) it stays None.
        Drives the bundle group name ``f"bundle_{bundle}"``.
    group : str | int | None
        Group identifier, or None if the trace is ungrouped. When set, the
        trace is the result of grouping/merging fibers for this order, and
        there should be exactly one trace per (m, bundle, group). Strings are
        used for named groups ('A', 'B', 'cal') or bundle merges
        ('bundle_45'); int is kept for legacy compatibility.

        **Mutually exclusive with fiber_idx.** A trace has either:

        - group set (merged/grouped result) and fiber_idx=None, or
        - fiber_idx set (individual fiber) and group=None, or
        - both None (ungrouped single-fiber instrument)
    fiber_idx : int | None
        Physical fiber index (1-indexed) within (m, bundle), for multi-fiber
        instruments where fibers are tracked individually (not merged). Used
        for per-fiber wavelength calibration. There should be exactly one
        trace per (m, bundle, fiber_idx).
        **Mutually exclusive with group.**
    pos : np.ndarray
        y(x) trace position polynomial coefficients, shape (deg+1,), in
        numpy.polyval order (highest power first).
    column_range : tuple[int, int]
        Valid x range [start, end) for this trace.
    height : float | None
        Extraction aperture height in pixels. None to use settings default.
    slit : np.ndarray | None
        Slit curvature coefficients, shape (deg_y+1, deg_x+1).
        Evaluates to x_offset = P(y) where P's coefficients vary with x;
        slit[i, :] are the coefficients of the y^i term as a function of x.
    slitdelta : np.ndarray | None
        Per-row slit correction, shape (height_pixels,). Residual offsets
        beyond the polynomial fit.
    wave : np.ndarray | None
        Wavelength polynomial coefficients. Either:

        - 1D array, shape (deg+1,): per-trace polynomial evaluated with
          ``np.polyval`` (highest power first).
        - 2D array: global 2D polynomial shared across traces, evaluated with
          ``numpy.polynomial.polynomial.polyval2d`` so that ``wave[i, j]``
          multiplies ``x**i * idx**j``, where ``idx`` is ``_wave_idx``.
    _wave_idx : int | None
        Trace index used as the second coordinate of a 2D ``wave``
        polynomial.
    invalid : str | None
        Reason why the trace should be skipped, or None.
    """

    # Identity
    m: int | None

    # Geometry (required)
    pos: np.ndarray
    column_range: tuple[int, int]

    # Optional fields (dataclass rules: defaults must follow required fields)
    bundle: int | None = None
    group: str | int | None = None
    fiber_idx: int | None = None
    height: float | None = None
    slit: np.ndarray | None = None
    slitdelta: np.ndarray | None = None
    wave: np.ndarray | None = None
    _wave_idx: int | None = None  # trace index for 2D polynomial evaluation
    invalid: str | None = None  # reason if trace should be skipped

    def slit_at_x(self, x: float | np.ndarray) -> np.ndarray | None:
        """Evaluate the slit polynomial coefficients at column position x.

        Parameters
        ----------
        x : float or np.ndarray
            Column position(s) to evaluate at.

        Returns
        -------
        np.ndarray or None
            Coefficients c_i of x_offset = c0 + c1*y + c2*y^2 + ..., one
            entry per row of ``slit``. Shape (deg_y+1,) for scalar x, or
            (deg_y+1, len(x)) for a 1D array x.
            None if no slit curvature is set.
        """
        coeff_rows = self.slit
        if coeff_rows is None:
            return None
        # Each row holds the x-polynomial of one y-power; evaluate them all.
        return np.array([np.polyval(row, x) for row in coeff_rows])

    def wlen(self, x: np.ndarray) -> np.ndarray | None:
        """Evaluate the wavelength polynomial at column positions.

        Parameters
        ----------
        x : np.ndarray
            Column positions to evaluate at.

        Returns
        -------
        np.ndarray or None
            Wavelength at each x position. None if no wavelength calibration
            is set, or if ``wave`` is 2D but ``_wave_idx`` is None (a warning
            is logged in that case).
        """
        coeffs = self.wave
        if coeffs is None:
            return None

        if coeffs.ndim != 2:
            # Per-trace polynomial in polyval order.
            return np.polyval(coeffs, x)

        # Global 2D polynomial: coeffs[i, j] multiplies x^i * idx^j. It is
        # fitted against trace indices (0, 1, 2, ...), so the second
        # coordinate is _wave_idx and not the physical order number m.
        trace_index = self._wave_idx
        if trace_index is None:
            logger.warning(
                "Cannot evaluate 2D wavelength polynomial: trace._wave_idx is None."
            )
            return None
        index_coord = np.full_like(x, trace_index, dtype=float)
        return np.polynomial.polynomial.polyval2d(x, index_coord, coeffs)

    def y_at_x(self, x: np.ndarray) -> np.ndarray:
        """Evaluate the trace y-position at column positions.

        Parameters
        ----------
        x : np.ndarray
            Column positions to evaluate at.

        Returns
        -------
        np.ndarray
            Y position of the trace center at each x.
        """
        return np.polyval(self.pos, x)
def _validate_traces(traces: list[Trace], context: str = "") -> None:
"""Validate trace list invariants.
Checks:
1. (m, bundle, group) is unique for grouped traces (group is not None)
2. (m, bundle, fiber_idx) is unique for individual fiber traces
3. Traces are ordered by y-position (ascending)
Parameters
----------
traces : list[Trace]
Traces to validate.
context : str
Context for error messages (e.g., file path).
Raises
------
ValueError
If validation fails.
"""
if not traces:
return
# Check that group and fiber_idx are mutually exclusive
for i, t in enumerate(traces):
if t.group is not None and t.fiber_idx is not None:
raise ValueError(
f"Trace {i} has both group={t.group} and fiber_idx={t.fiber_idx}. "
f"These are mutually exclusive: group indicates merged fiber result, "
f"fiber_idx indicates individual fiber{context}"
)
# Check (m, bundle, group) uniqueness for grouped traces, and that
# bundle merges encode their bundle id consistently in both fields
# (group="bundle_5" must mean bundle=5).
seen_group = set()
bundle_pat = re.compile(r"^bundle_(\d+)$")
for t in traces:
if t.group is not None:
key = (t.m, t.bundle, t.group)
if key in seen_group:
raise ValueError(
f"Duplicate (m={t.m}, bundle={t.bundle}, group={t.group}) "
f"in traces{context}"
)
seen_group.add(key)
match = bundle_pat.match(str(t.group))
if match and t.bundle != int(match.group(1)):
raise ValueError(
f"Trace group={t.group!r} does not match bundle={t.bundle} "
f"(m={t.m}){context}"
)
# Check (m, bundle, fiber_idx) uniqueness for fiber traces
seen_fiber = set()
for t in traces:
if t.fiber_idx is not None:
key = (t.m, t.bundle, t.fiber_idx)
if key in seen_fiber:
raise ValueError(
f"Duplicate (m={t.m}, bundle={t.bundle}, "
f"fiber_idx={t.fiber_idx}) in traces{context}"
)
seen_fiber.add(key)
# Check y-position ordering (evaluate at midpoint of column range)
# Only applies to ungrouped traces - grouped traces are organized by (m, group)
has_groups = any(t.group is not None for t in traces)
if not has_groups:
ref_x = (traces[0].column_range[0] + traces[0].column_range[1]) // 2
y_positions = [t.y_at_x(ref_x) for t in traces]
for i in range(1, len(y_positions)):
if y_positions[i] < y_positions[i - 1]:
logger.warning(
"Traces not ordered by y-position at x=%d: trace %d (y=%.1f) < trace %d (y=%.1f)%s",
ref_x,
i,
y_positions[i],
i - 1,
y_positions[i - 1],
context,
)
[docs]
def save_traces(
    path: str | Path,
    traces: list[Trace],
    header: fits.Header | None = None,
    steps: list[str] | None = None,
) -> None:
    """Save traces to a FITS binary table.

    The primary HDU carries the header; the traces are written to a binary
    table extension named ``TRACES`` with one row per trace. Rows whose
    arrays are shorter than the widest one in a column are padded with NaN,
    which ``load_traces`` strips again.

    Parameters
    ----------
    path : str | Path
        Output file path. An existing file is overwritten.
    traces : list[Trace]
        Traces to save.
    header : fits.Header, optional
        FITS header to include. It is copied, so the caller's header is not
        modified. If None, a minimal header is created.
    steps : list[str], optional
        Pipeline steps that have been run (stored comma-joined in the
        E_STEPS header keyword).

    Raises
    ------
    ValueError
        If ``traces`` is empty, or if ``_validate_traces`` rejects the list
        (e.g. duplicate (m, bundle, group) or (m, bundle, fiber_idx) keys).
    """
    if not traces:
        raise ValueError("Cannot save empty trace list")

    _validate_traces(traces, f" when saving to {path}")

    # Work on a copy so the caller's header is never mutated.
    header = fits.Header() if header is None else header.copy()

    # Add format metadata
    header["E_FMTVER"] = (FORMAT_VERSION, "PyReduce format version")
    if steps:
        header["E_STEPS"] = (",".join(steps), "Pipeline steps run")

    ntrace = len(traces)

    # ------------------------------------------------------------------
    # Determine array sizes
    # ------------------------------------------------------------------
    max_pos_deg = max(len(t.pos) for t in traces)

    # Wave can be 1D (per-trace) or 2D (global polynomial)
    wave_shapes = [t.wave.shape if t.wave is not None else () for t in traces]
    wave_is_2d = any(len(s) == 2 for s in wave_shapes)
    if wave_is_2d:
        max_wave_x = max((s[0] if len(s) == 2 else 0) for s in wave_shapes)
        max_wave_m = max((s[1] if len(s) == 2 else 0) for s in wave_shapes)
        max_wave_deg = 0  # Not used for 2D
        if any(len(s) == 1 for s in wave_shapes):
            # Only 2D polynomials fit into the 2D WAVE column; make the
            # omission of the 1D ones visible instead of dropping them silently.
            logger.warning(
                "Traces mix 1D and 2D wavelength polynomials; "
                "the 1D polynomials are not written to: %s",
                path,
            )
    else:
        max_wave_deg = max((s[0] if len(s) >= 1 else 0) for s in wave_shapes)
        max_wave_x = max_wave_m = 0

    max_slitdelta_len = max(
        (len(t.slitdelta) if t.slitdelta is not None else 0) for t in traces
    )

    # Slit dimensions (deg_y+1, deg_x+1)
    slit_shapes = [(t.slit.shape if t.slit is not None else (0, 0)) for t in traces]
    max_slit_y = max(s[0] for s in slit_shapes)
    max_slit_x = max(s[1] for s in slit_shapes)

    # ------------------------------------------------------------------
    # Build arrays (missing integer values are stored as -1, missing
    # strings as "", missing floats as NaN)
    # ------------------------------------------------------------------
    m_arr = np.array([t.m if t.m is not None else -1 for t in traces], dtype=np.int16)
    bundle_arr = np.array(
        [t.bundle if t.bundle is not None else -1 for t in traces], dtype=np.int16
    )

    # Keep the historical 16-character width, but widen the column when a
    # group label is longer so that it is not silently truncated.
    group_names = [str(t.group) if t.group is not None else "" for t in traces]
    group_width = max(16, max(len(name) for name in group_names))
    group_arr = np.array(group_names, dtype=f"U{group_width}")

    fiber_idx_arr = np.array(
        [t.fiber_idx if t.fiber_idx is not None else -1 for t in traces], dtype=np.int16
    )
    col_range_arr = np.array([t.column_range for t in traces], dtype=np.int32)
    height_arr = np.array(
        [t.height if t.height is not None else np.nan for t in traces], dtype=np.float32
    )

    # Pad short rows with NaN, not zeros: the coefficients are in polyval
    # order (highest power first), so trailing zeros would turn a shorter
    # polynomial into a different one. load_traces removes the NaN entries.
    pos_arr = np.full((ntrace, max_pos_deg), np.nan, dtype=np.float64)
    for i, t in enumerate(traces):
        pos_arr[i, : len(t.pos)] = t.pos

    wave_arr = None
    if wave_is_2d and max_wave_x > 0 and max_wave_m > 0:
        # 2D wavelength polynomial
        wave_arr = np.full((ntrace, max_wave_x, max_wave_m), np.nan, dtype=np.float64)
        for i, t in enumerate(traces):
            if t.wave is not None and t.wave.ndim == 2:
                wx, wm = t.wave.shape
                wave_arr[i, :wx, :wm] = t.wave
    elif max_wave_deg > 0:
        # 1D wavelength polynomial per trace
        wave_arr = np.full((ntrace, max_wave_deg), np.nan, dtype=np.float64)
        for i, t in enumerate(traces):
            if t.wave is not None and t.wave.ndim == 1:
                wave_arr[i, : len(t.wave)] = t.wave

    slit_arr = None
    if max_slit_y > 0 and max_slit_x > 0:
        slit_arr = np.full((ntrace, max_slit_y, max_slit_x), np.nan, dtype=np.float64)
        for i, t in enumerate(traces):
            if t.slit is not None:
                sy, sx = t.slit.shape
                slit_arr[i, :sy, :sx] = t.slit

    slitdelta_arr = None
    if max_slitdelta_len > 0:
        slitdelta_arr = np.full((ntrace, max_slitdelta_len), np.nan, dtype=np.float32)
        for i, t in enumerate(traces):
            if t.slitdelta is not None:
                slitdelta_arr[i, : len(t.slitdelta)] = t.slitdelta

    # ------------------------------------------------------------------
    # Build FITS columns
    # ------------------------------------------------------------------
    columns = [
        fits.Column(name="M", format="I", array=m_arr),
        fits.Column(name="BUNDLE", format="I", array=bundle_arr),
        fits.Column(name="GROUP", format=f"{group_width}A", array=group_arr),
        fits.Column(name="FIBER_IDX", format="I", array=fiber_idx_arr),
        fits.Column(name="POS", format=f"{max_pos_deg}D", array=pos_arr),
        fits.Column(name="COL_RANGE", format="2J", array=col_range_arr),
        fits.Column(name="HEIGHT", format="E", array=height_arr),
    ]

    if slit_arr is not None:
        slit_flat = slit_arr.reshape(ntrace, -1)
        columns.append(
            fits.Column(
                name="SLIT",
                format=f"{slit_flat.shape[1]}D",
                array=slit_flat,
                # dim is given in reversed order relative to the numpy
                # (y, x) cell shape
                dim=f"({max_slit_x},{max_slit_y})",
            )
        )
        # The cell shape is also stored in the header; load_traces uses it
        # to reshape the column.
        header["SLIT_Y"] = (max_slit_y, "Slit polynomial y-degree + 1")
        header["SLIT_X"] = (max_slit_x, "Slit polynomial x-degree + 1")

    if slitdelta_arr is not None:
        columns.append(
            fits.Column(
                name="SLITDELTA", format=f"{max_slitdelta_len}E", array=slitdelta_arr
            )
        )

    if wave_arr is not None:
        if wave_is_2d:
            wave_flat = wave_arr.reshape(ntrace, -1)
            columns.append(
                fits.Column(
                    name="WAVE",
                    format=f"{wave_flat.shape[1]}D",
                    array=wave_flat,
                    dim=f"({max_wave_m},{max_wave_x})",
                )
            )
            # Presence of WAVE_X/WAVE_M tells load_traces the column is 2D.
            header["WAVE_X"] = (max_wave_x, "Wave polynomial x-degree + 1")
            header["WAVE_M"] = (max_wave_m, "Wave polynomial m-degree + 1")
        else:
            columns.append(
                fits.Column(name="WAVE", format=f"{max_wave_deg}D", array=wave_arr)
            )

    # ------------------------------------------------------------------
    # Create HDU list and write
    # ------------------------------------------------------------------
    primary = fits.PrimaryHDU(header=header)
    table = fits.BinTableHDU.from_columns(columns, name="TRACES")
    hdulist = fits.HDUList([primary, table])
    hdulist.writeto(path, overwrite=True, output_verify="silentfix+ignore")

    logger.info("Saved %d traces to: %s", ntrace, path)
[docs]
def load_traces(path: str | Path) -> tuple[list[Trace], fits.Header]:
    """Load traces from a FITS file.

    Files with the ``.npz`` suffix are handed to the legacy NPZ loader for
    backwards compatibility.

    Parameters
    ----------
    path : str | Path
        Input file path (.fits or .npz).

    Returns
    -------
    traces : list[Trace]
        Loaded traces.
    header : fits.Header
        Primary FITS header (empty for NPZ files).

    Raises
    ------
    ValueError
        If the loaded traces fail ``_validate_traces``.
    """

    def drop_all_nan_lines(grid):
        # Remove rows and columns of a 2D array that consist only of NaN
        # (the padding of narrower cells in a wider column).
        keep_rows = ~np.all(np.isnan(grid), axis=1)
        keep_cols = ~np.all(np.isnan(grid), axis=0)
        return grid[keep_rows][:, keep_cols]

    path = Path(path)
    if path.suffix == ".npz":
        return _load_traces_npz(path)

    with fits.open(path, memmap=False) as hdul:
        header = hdul[0].header
        version = header.get("E_FMTVER", 1)
        if version < 2:
            logger.warning("Loading traces from old format (version %d)", version)

        table = hdul["TRACES"].data
        available = table.dtype.names

        orders = table["M"]
        bundles = table["BUNDLE"] if "BUNDLE" in available else None
        # New files use GROUP; v2 files used the FIBER column name
        groups = table["GROUP"] if "GROUP" in available else table["FIBER"]
        fiber_ids = table["FIBER_IDX"] if "FIBER_IDX" in available else None
        positions = table["POS"]
        col_ranges = table["COL_RANGE"]
        heights = table["HEIGHT"]
        slits = table["SLIT"] if "SLIT" in available else None
        slitdeltas = table["SLITDELTA"] if "SLITDELTA" in available else None
        waves = table["WAVE"] if "WAVE" in available else None

    # Restore the (y, x) cell shape of the slit column from the header
    if slits is not None:
        n_slit_y = header.get("SLIT_Y", 0)
        n_slit_x = header.get("SLIT_X", 0)
        if n_slit_y > 0 and n_slit_x > 0:
            slits = slits.reshape(-1, n_slit_y, n_slit_x)

    # WAVE_X / WAVE_M in the header mark a 2D wavelength polynomial
    wave_is_2d = False
    if waves is not None:
        n_wave_x = header.get("WAVE_X", 0)
        n_wave_m = header.get("WAVE_M", 0)
        if n_wave_x > 0 and n_wave_m > 0:
            waves = waves.reshape(-1, n_wave_x, n_wave_m)
            wave_is_2d = True

    traces = []
    for row, order in enumerate(orders):
        # Negative integers are the stored form of "not set"
        m = int(order) if order >= 0 else None

        # Empty string or "0" means no group (backward compat); labels that
        # look like integers are converted to int.
        label = groups[row].strip()
        if label in ("", "0"):
            group = None
        else:
            try:
                group = int(label)
            except ValueError:
                group = label

        if fiber_ids is not None and fiber_ids[row] >= 0:
            fiber_idx = int(fiber_ids[row])
        else:
            fiber_idx = None

        if bundles is not None and bundles[row] >= 0:
            bundle = int(bundles[row])
        else:
            bundle = None

        # Drop NaN entries from the position coefficients
        pos = positions[row]
        pos_is_nan = np.isnan(pos)
        if np.any(pos_is_nan):
            pos = pos[~pos_is_nan]

        column_range = (int(col_ranges[row, 0]), int(col_ranges[row, 1]))

        raw_height = heights[row]
        height = None if np.isnan(raw_height) else float(raw_height)

        slit = None
        if slits is not None:
            cell = slits[row]
            if not np.all(np.isnan(cell)):
                slit = drop_all_nan_lines(cell)

        slitdelta = None
        if slitdeltas is not None:
            cell = slitdeltas[row]
            if not np.all(np.isnan(cell)):
                slitdelta = cell[~np.isnan(cell)]

        wave = None
        if waves is not None:
            cell = waves[row]
            if np.all(np.isnan(cell)):
                wave = None
            elif wave_is_2d:
                # 2D polynomial - remove all-NaN rows/cols
                wave = drop_all_nan_lines(cell)
            else:
                # 1D polynomial - remove NaN padding
                wave = cell[~np.isnan(cell)]

        traces.append(
            Trace(
                m=m,
                bundle=bundle,
                group=group,
                fiber_idx=fiber_idx,
                pos=pos,
                column_range=column_range,
                height=height,
                slit=slit,
                slitdelta=slitdelta,
                wave=wave,
            )
        )

    # Reconstruct _wave_idx for 2D wave polynomials: traces carrying a 2D
    # polynomial are numbered sequentially (0, 1, 2, ...) within each group,
    # in file order.
    if wave_is_2d:
        next_index: dict = {}
        for trace in traces:
            if trace.wave is None or trace.wave.ndim != 2:
                continue
            trace._wave_idx = next_index.get(trace.group, 0)
            next_index[trace.group] = trace._wave_idx + 1

    logger.info("Loaded %d traces from: %s", len(traces), path)
    _validate_traces(traces, f" loaded from {path}")

    return traces, header
def _load_traces_npz(path: Path) -> tuple[list[Trace], fits.Header]:
    """Load traces from legacy NPZ format.

    This handles the old format where traces, column_range, and heights
    were stored as separate arrays without order/fiber identity.

    Parameters
    ----------
    path : Path
        Input NPZ file path.

    Returns
    -------
    traces : list[Trace]
        Loaded traces. ``m`` is assigned sequentially from 0; bundle, group
        and fiber_idx are left unset.
    header : fits.Header
        Empty header.

    Raises
    ------
    KeyError
        If the file contains neither a ``traces`` nor an ``orders`` array,
        or no ``column_range`` array.
    ValueError
        If the loaded traces fail ``_validate_traces``.
    """
    # NOTE(review): allow_pickle=True can execute arbitrary code when loading
    # an untrusted file. Presumably it is needed because legacy files stored
    # object arrays (e.g. heights=None) - confirm before tightening.
    # The context manager closes the underlying file handle; every array is
    # read into memory inside the block.
    with np.load(path, allow_pickle=True) as data:
        # Handle old 'orders' key name
        if "orders" in data and "traces" not in data:
            trace_coeffs = data["orders"]
        else:
            trace_coeffs = data["traces"]

        column_range = data["column_range"]

        # Heights may or may not be present
        heights = data["heights"] if "heights" in data else None

    # A 0-dimensional heights array cannot be indexed per trace; treat it
    # as "no heights".
    if heights is not None and heights.ndim == 0:
        heights = None

    traces = []
    for i, coeffs in enumerate(trace_coeffs):
        if heights is not None and not np.isnan(heights[i]):
            height = float(heights[i])
        else:
            height = None
        traces.append(
            Trace(
                m=i,  # Sequential order number (no identity preserved)
                pos=coeffs,
                column_range=(int(column_range[i, 0]), int(column_range[i, 1])),
                height=height,
            )
        )

    logger.info("Loaded %d traces from legacy NPZ: %s", len(traces), path)
    _validate_traces(traces, f" loaded from {path}")

    return traces, fits.Header()