# Source code for pyreduce.trace_model

"""
Trace data model for PyReduce.

This module defines the Trace dataclass and I/O functions for storing
trace positions, curvature, and wavelength calibration in FITS format.

The Trace dataclass consolidates what was previously scattered across
separate files (traces.npz, curve.npz, wavecal.npz) into a single structure.
"""

from __future__ import annotations

import logging
import re
from dataclasses import dataclass
from pathlib import Path

import astropy.io.fits as fits
import numpy as np

logger = logging.getLogger(__name__)

# Format version for backwards compatibility detection
# v2: Initial FITS format with FIBER column
# v3: Renamed FIBER→GROUP, added FIBER_IDX column
# v4: Added BUNDLE column (bundle id, independent of m)
FORMAT_VERSION = 4


@dataclass
class Trace:
    """Container for a single trace's geometry and calibration data.

    A trace represents a single spectral order (or fiber within an order)
    on the detector.

    Attributes
    ----------
    m : int | None
        Spectral order number (diffraction order). This is the physical
        order number from the grating equation, not a sequential index.
        In echelle spectrographs, higher order numbers correspond to
        shorter wavelengths.

        The order number is assigned in one of three ways:

        1. **From order_centers.yaml** (preferred): If the instrument
           provides an ``order_centers_{channel}.yaml`` file with known
           order positions, traces are matched to these centers during
           detection and assigned the corresponding order numbers
           immediately.

        2. **From wavelength calibration**: If no order_centers file
           exists, ``m`` is initially None. During wavelength calibration,
           the linelist file provides ``obase`` (the base order number).
           Each trace is then assigned ``m = obase + trace_index``.

        3. **Sequential fallback**: For legacy files or MOSAIC mode where
           order identity cannot be determined, ``m`` may remain None or
           be assigned sequentially from 0.

        ``m`` only identifies the physical order. ``Trace.wlen()`` does
        **not** use it: a 2D wavelength polynomial is evaluated at the
        trace index ``_wave_idx`` (see below), not at ``m``.
    bundle : int | None
        Bundle identifier (1-indexed), independent of `m`. Used by
        instruments where fibers are organised into spatial bundles within
        each spectral order (MOSAIC: 90 bundles × 7 fibers; in principle
        also bundled echelle: m orders × bundles). For instruments without
        a bundle concept (ANDES groups, single-fiber echelles), bundle
        stays None. Drives the bundle group name ``f"bundle_{bundle}"``.
    group : str | int | None
        Group identifier, or None if trace is ungrouped. When set,
        indicates this trace is the result of grouping/merging fibers for
        this order. There should be exactly one trace per
        (m, bundle, group). String for named groups ('A', 'B', 'cal') or
        bundle merges ('bundle_45'); int kept for legacy compatibility.

        **Mutually exclusive with fiber_idx.** A trace has either:

        - group set (merged/grouped result) and fiber_idx=None, or
        - fiber_idx set (individual fiber) and group=None, or
        - both None (ungrouped single-fiber instrument)
    fiber_idx : int | None
        Physical fiber index (1-indexed) within (m, bundle). For
        multi-fiber instruments where fibers are tracked individually
        (not merged). Used for per-fiber wavelength calibration. There
        should be exactly one trace per (m, bundle, fiber_idx).

        **Mutually exclusive with group.** See group docstring for details.
    pos : np.ndarray
        y(x) trace position polynomial coefficients, shape (deg+1,).
        Coefficients in numpy.polyval order (highest power first).
    column_range : tuple[int, int]
        Valid x range [start, end) for this trace.
    height : float | None
        Extraction aperture height in pixels. None to use settings default.
    slit : np.ndarray | None
        Slit curvature coefficients, shape (deg_y+1, deg_x+1).
        Evaluates to x_offset = P(y) where P's coefficients vary with x.
        slit[i, :] are coefficients for the y^i term as a function of x,
        in numpy.polyval order (highest power of x first).
    slitdelta : np.ndarray | None
        Per-row slit correction, shape (height_pixels,).
        Residual offsets beyond polynomial fit.
    wave : np.ndarray | None
        Wavelength polynomial coefficients. Can be:

        - 1D array, shape (deg+1,): per-trace polynomial evaluated with
          ``np.polyval(wave, x)`` (highest power first).
        - 2D array, shape (deg_x+1, deg_idx+1): global 2D polynomial
          shared across traces, evaluated with
          ``np.polynomial.polynomial.polyval2d(x, _wave_idx, wave)``, i.e.
          ``wave[i, j]`` is the coefficient of ``x**i * _wave_idx**j``
          (lowest power first).
    _wave_idx : int | None
        Trace index used as the second coordinate when evaluating a 2D
        ``wave`` polynomial. Must be set for ``wlen()`` to evaluate a 2D
        polynomial; otherwise ``wlen()`` logs a warning and returns None.
    invalid : str | None
        Reason the trace should be skipped, or None if the trace is usable.
    """

    # Identity
    m: int | None

    # Geometry (required)
    pos: np.ndarray
    column_range: tuple[int, int]

    # Optional fields (must come after required fields)
    bundle: int | None = None
    group: str | int | None = None
    fiber_idx: int | None = None
    height: float | None = None
    slit: np.ndarray | None = None
    slitdelta: np.ndarray | None = None
    wave: np.ndarray | None = None
    _wave_idx: int | None = None  # trace index for 2D polynomial evaluation
    invalid: str | None = None  # reason if trace should be skipped

    def slit_at_x(self, x: float | np.ndarray) -> np.ndarray | None:
        """Evaluate slit polynomial coefficients at position x.

        Parameters
        ----------
        x : float or np.ndarray
            Column position(s) to evaluate at.

        Returns
        -------
        np.ndarray or None
            Coefficients ``c`` of the offset polynomial in y,
            ``offset = c[0] + c[1]*y + c[2]*y^2 + ...``, evaluated at x.
            Shape (deg_y+1,) for scalar x, or (deg_y+1, len(x)) for a 1D
            array x (the coefficient index is always the first axis).
            Returns None if no slit curvature is set.
        """
        if self.slit is None:
            return None
        # slit[i, :] = coefficients for y^i term as function of x.
        # Evaluate each row's polynomial at x; row i of the result is the
        # y^i coefficient at every requested x.
        return np.array([np.polyval(row, x) for row in self.slit])

    def wlen(self, x: np.ndarray) -> np.ndarray | None:
        """Evaluate wavelength polynomial at column positions.

        Parameters
        ----------
        x : np.ndarray
            Column positions to evaluate at.

        Returns
        -------
        np.ndarray or None
            Wavelength values at each x position. Returns None if no
            wavelength calibration is set, or if ``wave`` is a 2D
            polynomial and ``_wave_idx`` is None (a warning is logged in
            that case).
        """
        if self.wave is None:
            return None

        if self.wave.ndim != 2:
            # 1D polynomial: standard polyval (highest power first)
            return np.polyval(self.wave, x)

        # 2D polynomial: wave[i, j] is the coefficient for x^i * idx^j.
        # The polynomial is fitted with trace indices (0, 1, 2, ...),
        # so we must use _wave_idx (not physical order number m).
        idx = self._wave_idx
        if idx is None:
            logger.warning(
                "Cannot evaluate 2D wavelength polynomial: trace._wave_idx is None."
            )
            return None
        idx_arr = np.full_like(x, idx, dtype=float)
        return np.polynomial.polynomial.polyval2d(x, idx_arr, self.wave)

    def y_at_x(self, x: np.ndarray) -> np.ndarray:
        """Evaluate trace y-position at column positions.

        Parameters
        ----------
        x : np.ndarray
            Column positions to evaluate at.

        Returns
        -------
        np.ndarray
            Y positions of the trace center at each x.
        """
        return np.polyval(self.pos, x)
def _validate_traces(traces: list[Trace], context: str = "") -> None: """Validate trace list invariants. Checks: 1. (m, bundle, group) is unique for grouped traces (group is not None) 2. (m, bundle, fiber_idx) is unique for individual fiber traces 3. Traces are ordered by y-position (ascending) Parameters ---------- traces : list[Trace] Traces to validate. context : str Context for error messages (e.g., file path). Raises ------ ValueError If validation fails. """ if not traces: return # Check that group and fiber_idx are mutually exclusive for i, t in enumerate(traces): if t.group is not None and t.fiber_idx is not None: raise ValueError( f"Trace {i} has both group={t.group} and fiber_idx={t.fiber_idx}. " f"These are mutually exclusive: group indicates merged fiber result, " f"fiber_idx indicates individual fiber{context}" ) # Check (m, bundle, group) uniqueness for grouped traces, and that # bundle merges encode their bundle id consistently in both fields # (group="bundle_5" must mean bundle=5). 
seen_group = set() bundle_pat = re.compile(r"^bundle_(\d+)$") for t in traces: if t.group is not None: key = (t.m, t.bundle, t.group) if key in seen_group: raise ValueError( f"Duplicate (m={t.m}, bundle={t.bundle}, group={t.group}) " f"in traces{context}" ) seen_group.add(key) match = bundle_pat.match(str(t.group)) if match and t.bundle != int(match.group(1)): raise ValueError( f"Trace group={t.group!r} does not match bundle={t.bundle} " f"(m={t.m}){context}" ) # Check (m, bundle, fiber_idx) uniqueness for fiber traces seen_fiber = set() for t in traces: if t.fiber_idx is not None: key = (t.m, t.bundle, t.fiber_idx) if key in seen_fiber: raise ValueError( f"Duplicate (m={t.m}, bundle={t.bundle}, " f"fiber_idx={t.fiber_idx}) in traces{context}" ) seen_fiber.add(key) # Check y-position ordering (evaluate at midpoint of column range) # Only applies to ungrouped traces - grouped traces are organized by (m, group) has_groups = any(t.group is not None for t in traces) if not has_groups: ref_x = (traces[0].column_range[0] + traces[0].column_range[1]) // 2 y_positions = [t.y_at_x(ref_x) for t in traces] for i in range(1, len(y_positions)): if y_positions[i] < y_positions[i - 1]: logger.warning( "Traces not ordered by y-position at x=%d: trace %d (y=%.1f) < trace %d (y=%.1f)%s", ref_x, i, y_positions[i], i - 1, y_positions[i - 1], context, )
def save_traces(
    path: str | Path,
    traces: list[Trace],
    header: fits.Header | None = None,
    steps: list[str] | None = None,
) -> None:
    """Save traces to a FITS binary table.

    The file consists of a primary HDU carrying ``header`` plus format
    metadata, and a binary table named ``TRACES`` with one row per trace.
    The columns M, BUNDLE, GROUP, FIBER_IDX, POS, COL_RANGE and HEIGHT are
    always written; SLIT, SLITDELTA and WAVE are written only if at least
    one trace provides the corresponding data.

    Missing values are encoded with sentinels: -1 for ``m``, ``bundle``
    and ``fiber_idx``, an empty string for ``group`` and NaN for
    ``height``. Array columns are padded with NaN up to the longest entry.
    ``Trace._wave_idx`` and ``Trace.invalid`` are not stored.

    If any trace carries a 2D ``wave`` polynomial, the WAVE column is
    written in 2D form and 1D ``wave`` arrays on other traces are not
    stored.

    Parameters
    ----------
    path : str | Path
        Output file path. An existing file is overwritten.
    traces : list[Trace]
        Traces to save.
    header : fits.Header, optional
        FITS header to include. A copy is modified, the caller's header is
        left untouched. If None, a minimal header is created.
    steps : list[str], optional
        Pipeline steps that have been run (stored in E_STEPS header).

    Raises
    ------
    ValueError
        If ``traces`` is empty, or if ``_validate_traces`` rejects the
        list (e.g. duplicate (m, bundle, group) keys).
    """
    if not traces:
        raise ValueError("Cannot save empty trace list")

    _validate_traces(traces, f" when saving to {path}")

    header = fits.Header() if header is None else header.copy()

    # Add format metadata
    header["E_FMTVER"] = (FORMAT_VERSION, "PyReduce format version")
    if steps:
        header["E_STEPS"] = (",".join(steps), "Pipeline steps run")

    ntrace = len(traces)

    # Determine array sizes
    max_pos_deg = max(len(t.pos) for t in traces)

    # Determine wave dimensions - can be 1D (per-trace) or 2D (global poly)
    wave_shapes = [t.wave.shape if t.wave is not None else () for t in traces]
    wave_is_2d = any(len(s) == 2 for s in wave_shapes)
    if wave_is_2d:
        max_wave_x = max((s[0] if len(s) == 2 else 0) for s in wave_shapes)
        max_wave_m = max((s[1] if len(s) == 2 else 0) for s in wave_shapes)
        max_wave_deg = 0  # Not used for 2D
    else:
        max_wave_deg = max((s[0] if len(s) >= 1 else 0) for s in wave_shapes)
        max_wave_x = max_wave_m = 0

    max_slitdelta_len = max(
        (len(t.slitdelta) if t.slitdelta is not None else 0) for t in traces
    )

    # Determine slit dimensions (deg_y+1, deg_x+1)
    slit_shapes = [(t.slit.shape if t.slit is not None else (0, 0)) for t in traces]
    max_slit_y = max(s[0] for s in slit_shapes)
    max_slit_x = max(s[1] for s in slit_shapes)

    # Build arrays
    m_arr = np.array([t.m if t.m is not None else -1 for t in traces], dtype=np.int16)
    bundle_arr = np.array(
        [t.bundle if t.bundle is not None else -1 for t in traces], dtype=np.int16
    )

    # The GROUP column has a fixed width; numpy truncates longer strings
    # without complaint, so make the data loss visible.
    group_width = 16
    group_names = [str(t.group) if t.group is not None else "" for t in traces]
    too_long = sorted({name for name in group_names if len(name) > group_width})
    if too_long:
        logger.warning(
            "Group names longer than %d characters are truncated on save: %s",
            group_width,
            too_long,
        )
    group_arr = np.array(group_names, dtype=f"U{group_width}")

    fiber_idx_arr = np.array(
        [t.fiber_idx if t.fiber_idx is not None else -1 for t in traces],
        dtype=np.int16,
    )
    col_range_arr = np.array([t.column_range for t in traces], dtype=np.int32)
    height_arr = np.array(
        [t.height if t.height is not None else np.nan for t in traces],
        dtype=np.float32,
    )

    # Pad shorter coefficient arrays with NaN, not zeros: coefficients are
    # in np.polyval order (highest power first), so trailing zeros would
    # turn the polynomial into a different, higher-degree one. load_traces
    # strips NaN entries, which restores the original array.
    pos_arr = np.full((ntrace, max_pos_deg), np.nan, dtype=np.float64)
    for i, t in enumerate(traces):
        pos_arr[i, : len(t.pos)] = t.pos

    wave_arr = None
    if wave_is_2d and max_wave_x > 0 and max_wave_m > 0:
        # 2D wavelength polynomial
        wave_arr = np.full((ntrace, max_wave_x, max_wave_m), np.nan, dtype=np.float64)
        for i, t in enumerate(traces):
            if t.wave is not None and t.wave.ndim == 2:
                wx, wm = t.wave.shape
                wave_arr[i, :wx, :wm] = t.wave
    elif max_wave_deg > 0:
        # 1D wavelength polynomial per trace
        wave_arr = np.full((ntrace, max_wave_deg), np.nan, dtype=np.float64)
        for i, t in enumerate(traces):
            if t.wave is not None and t.wave.ndim == 1:
                wave_arr[i, : len(t.wave)] = t.wave

    slit_arr = None
    if max_slit_y > 0 and max_slit_x > 0:
        slit_arr = np.full((ntrace, max_slit_y, max_slit_x), np.nan, dtype=np.float64)
        for i, t in enumerate(traces):
            if t.slit is not None:
                sy, sx = t.slit.shape
                slit_arr[i, :sy, :sx] = t.slit

    slitdelta_arr = None
    if max_slitdelta_len > 0:
        slitdelta_arr = np.full((ntrace, max_slitdelta_len), np.nan, dtype=np.float32)
        for i, t in enumerate(traces):
            if t.slitdelta is not None:
                slitdelta_arr[i, : len(t.slitdelta)] = t.slitdelta

    # Build FITS columns
    columns = [
        fits.Column(name="M", format="I", array=m_arr),
        fits.Column(name="BUNDLE", format="I", array=bundle_arr),
        fits.Column(name="GROUP", format=f"{group_width}A", array=group_arr),
        fits.Column(name="FIBER_IDX", format="I", array=fiber_idx_arr),
        fits.Column(name="POS", format=f"{max_pos_deg}D", array=pos_arr),
        fits.Column(name="COL_RANGE", format="2J", array=col_range_arr),
        fits.Column(name="HEIGHT", format="E", array=height_arr),
    ]

    if slit_arr is not None:
        # Stored flattened; the dim string lists the axes in reverse
        # (x, y) order relative to the (y, x) array layout.
        slit_flat = slit_arr.reshape(ntrace, -1)
        columns.append(
            fits.Column(
                name="SLIT",
                format=f"{slit_flat.shape[1]}D",
                array=slit_flat,
                dim=f"({max_slit_x},{max_slit_y})",
            )
        )
        header["SLIT_Y"] = (max_slit_y, "Slit polynomial y-degree + 1")
        header["SLIT_X"] = (max_slit_x, "Slit polynomial x-degree + 1")

    if slitdelta_arr is not None:
        columns.append(
            fits.Column(
                name="SLITDELTA", format=f"{max_slitdelta_len}E", array=slitdelta_arr
            )
        )

    if wave_arr is not None:
        if wave_is_2d:
            wave_flat = wave_arr.reshape(ntrace, -1)
            columns.append(
                fits.Column(
                    name="WAVE",
                    format=f"{wave_flat.shape[1]}D",
                    array=wave_flat,
                    dim=f"({max_wave_m},{max_wave_x})",
                )
            )
            # WAVE_X / WAVE_M let the loader recognise and reshape the
            # 2D form; they are absent for 1D polynomials.
            header["WAVE_X"] = (max_wave_x, "Wave polynomial x-degree + 1")
            header["WAVE_M"] = (max_wave_m, "Wave polynomial m-degree + 1")
        else:
            columns.append(
                fits.Column(name="WAVE", format=f"{max_wave_deg}D", array=wave_arr)
            )

    # Create HDU list
    primary = fits.PrimaryHDU(header=header)
    table = fits.BinTableHDU.from_columns(columns, name="TRACES")
    hdulist = fits.HDUList([primary, table])

    hdulist.writeto(path, overwrite=True, output_verify="silentfix+ignore")
    logger.info("Saved %d traces to: %s", ntrace, path)
def load_traces(path: str | Path) -> tuple[list[Trace], fits.Header]:
    """Read a list of traces back from disk.

    Files with a ``.npz`` suffix are handed to the legacy NPZ loader; any
    other path is opened as a FITS file whose ``TRACES`` table holds one
    row per trace. Older FITS layouts are accepted: the ``FIBER`` column
    is used when ``GROUP`` is absent, and ``BUNDLE`` / ``FIBER_IDX`` are
    optional.

    Parameters
    ----------
    path : str | Path
        Input file path (.fits or .npz).

    Returns
    -------
    traces : list[Trace]
        Loaded traces. Sentinel values in the table (-1, empty string,
        NaN) are turned back into None and NaN padding is removed.
    header : fits.Header
        Primary FITS header (empty for NPZ files).
    """
    path = Path(path)

    if path.suffix == ".npz":
        return _load_traces_npz(path)

    with fits.open(path, memmap=False) as hdulist:
        header = hdulist[0].header
        version = header.get("E_FMTVER", 1)
        if version < 2:
            logger.warning("Loading traces from old format (version %d)", version)

        table = hdulist["TRACES"].data
        names = table.dtype.names

        def optional(column):
            """Return the named column, or None if the table lacks it."""
            return table[column] if column in names else None

        m_col = table["M"]
        bundle_col = optional("BUNDLE")
        # GROUP is the current column name; FIBER is its v2 predecessor.
        group_col = table["GROUP"] if "GROUP" in names else table["FIBER"]
        fiber_col = optional("FIBER_IDX")
        pos_col = table["POS"]
        range_col = table["COL_RANGE"]
        height_col = table["HEIGHT"]
        slit_col = optional("SLIT")
        slitdelta_col = optional("SLITDELTA")
        wave_col = optional("WAVE")

        # Restore the (row, y, x) layout of the slit coefficients when
        # the header records their dimensions.
        if slit_col is not None:
            n_slit_y = header.get("SLIT_Y", 0)
            n_slit_x = header.get("SLIT_X", 0)
            if n_slit_y > 0 and n_slit_x > 0:
                slit_col = slit_col.reshape(-1, n_slit_y, n_slit_x)

        # WAVE_X / WAVE_M in the header mark a 2D wavelength polynomial.
        wave_is_2d = False
        if wave_col is not None:
            n_wave_x = header.get("WAVE_X", 0)
            n_wave_m = header.get("WAVE_M", 0)
            if n_wave_x > 0 and n_wave_m > 0:
                wave_col = wave_col.reshape(-1, n_wave_x, n_wave_m)
                wave_is_2d = True

        traces = []
        for row, raw_m in enumerate(m_col):
            order = int(raw_m) if raw_m >= 0 else None

            # An empty string or "0" stands for "no group" (backward
            # compat); anything that parses as an integer becomes an int.
            label = group_col[row].strip()
            if label in ("", "0"):
                group = None
            else:
                try:
                    group = int(label)
                except ValueError:
                    group = label

            fiber_idx = None
            if fiber_col is not None and fiber_col[row] >= 0:
                fiber_idx = int(fiber_col[row])

            bundle = None
            if bundle_col is not None and bundle_col[row] >= 0:
                bundle = int(bundle_col[row])

            # Drop NaN entries from the position coefficients.
            pos = pos_col[row]
            pos_is_nan = np.isnan(pos)
            if np.any(pos_is_nan):
                pos = pos[~pos_is_nan]

            column_range = (int(range_col[row, 0]), int(range_col[row, 1]))

            raw_height = height_col[row]
            height = None if np.isnan(raw_height) else float(raw_height)

            slit = None
            if slit_col is not None:
                candidate = slit_col[row]
                if not np.all(np.isnan(candidate)):
                    # Keep only rows/columns that hold at least one value.
                    keep_y = ~np.all(np.isnan(candidate), axis=1)
                    keep_x = ~np.all(np.isnan(candidate), axis=0)
                    slit = candidate[keep_y][:, keep_x]

            slitdelta = None
            if slitdelta_col is not None:
                candidate = slitdelta_col[row]
                if not np.all(np.isnan(candidate)):
                    slitdelta = candidate[~np.isnan(candidate)]

            wave = None
            if wave_col is not None:
                candidate = wave_col[row]
                if np.all(np.isnan(candidate)):
                    wave = None
                elif wave_is_2d:
                    # 2D polynomial - drop all-NaN rows/columns
                    keep_x = ~np.all(np.isnan(candidate), axis=1)
                    keep_m = ~np.all(np.isnan(candidate), axis=0)
                    wave = candidate[keep_x][:, keep_m]
                else:
                    # 1D polynomial - drop NaN entries
                    wave = candidate[~np.isnan(candidate)]

            traces.append(
                Trace(
                    m=order,
                    bundle=bundle,
                    group=group,
                    fiber_idx=fiber_idx,
                    pos=pos,
                    column_range=column_range,
                    height=height,
                    slit=slit,
                    slitdelta=slitdelta,
                    wave=wave,
                )
            )

        # _wave_idx is not stored in the file. Rebuild it for 2D wave
        # polynomials by numbering the traces of each group sequentially
        # from 0, in file order.
        if wave_is_2d:
            next_index: dict = {}
            for trace in traces:
                if trace.wave is not None and trace.wave.ndim == 2:
                    trace._wave_idx = next_index.get(trace.group, 0)
                    next_index[trace.group] = trace._wave_idx + 1

    logger.info("Loaded %d traces from: %s", len(traces), path)
    _validate_traces(traces, f" loaded from {path}")
    return traces, header
def _load_traces_npz(path: Path) -> tuple[list[Trace], fits.Header]:
    """Load traces from legacy NPZ format.

    This handles the old format where traces, column_range, and heights
    were stored as separate arrays without order/fiber identity.

    The archive must contain ``column_range`` and either ``traces`` or
    (older naming) ``orders``; ``heights`` is optional.

    Parameters
    ----------
    path : Path
        Input NPZ file path.

    Returns
    -------
    traces : list[Trace]
        Loaded traces. ``m`` is assigned sequentially from 0; bundle,
        group and fiber_idx are left as None.
    header : fits.Header
        Empty header.
    """
    # NOTE(security): allow_pickle=True means a crafted .npz can execute
    # arbitrary code while loading - only open files from trusted sources.
    # It is kept because legacy files presumably store a pickled None for
    # missing heights (see the ndim == 0 check below) - TODO confirm and
    # drop the flag if no such files remain.
    #
    # The context manager closes the archive's file handle; arrays are read
    # when indexed, so everything needed is pulled out inside the block.
    with np.load(path, allow_pickle=True) as data:
        # Handle old 'orders' key name
        if "orders" in data and "traces" not in data:
            trace_coeffs = data["orders"]
        else:
            trace_coeffs = data["traces"]

        column_range = data["column_range"]

        # Heights may or may not be present
        heights = data["heights"] if "heights" in data else None

    # A 0-d array carries no per-trace heights (e.g. a stored None).
    if heights is not None and heights.ndim == 0:
        heights = None

    traces = []
    for i, coeffs in enumerate(trace_coeffs):
        height = None
        if heights is not None and not np.isnan(heights[i]):
            height = float(heights[i])
        traces.append(
            Trace(
                m=i,  # Sequential order number (no identity preserved)
                pos=coeffs,
                column_range=(int(column_range[i, 0]), int(column_range[i, 1])),
                height=height,
            )
        )

    logger.info("Loaded %d traces from legacy NPZ: %s", len(traces), path)
    _validate_traces(traces, f" loaded from {path}")
    return traces, fits.Header()