"""
Spectrum data model for PyReduce.
This module defines the Spectrum, ExtractionParams, and Spectra classes
for storing extracted spectral data.
Replaces the legacy Echelle class with cleaner semantics:
- NaN masking instead of COLUMNS + MASK redundancy
- Per-trace metadata (m, fiber, extraction_height)
- Un-normalized spectra with separate continuum
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
import astropy.io.fits as fits
import numpy as np
import scipy.constants
if TYPE_CHECKING:
from pyreduce.trace_model import Trace
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Format version for backwards compatibility detection.
# Spectra.save() writes this value to the E_FMTVER header keyword, and
# Spectra.read() treats files with E_FMTVER >= 2 as the new format and
# everything else (including a missing keyword) as the legacy format.
FORMAT_VERSION = 2
@dataclass
class Spectrum:
    """Extraction result for a single trace.

    Attributes
    ----------
    m : int | None
        Spectral order number, or None when it is not known.
    spec : np.ndarray
        Un-normalized flux. Masked pixels are stored as NaN.
    sig : np.ndarray
        Flux uncertainty. Masked pixels are stored as NaN.
    group : str | int | None
        Group identifier (e.g. 'A', 'B', 'cal' or a bundle index).
        None when the trace has not been through fiber grouping.
    fiber_idx : int | None
        1-indexed fiber number inside the group, or None when unknown.
    wave : np.ndarray | None
        Wavelength per pixel (already evaluated), same length as ``spec``.
    cont : np.ndarray | None
        Continuum per pixel (full array), same length as ``spec``.
    slitfu : np.ndarray | None
        Slit function (length depends on osample: height * osample + 1).
    extraction_height : float | None
        Extraction aperture that was used for this trace.
    """

    # --- required fields; the order defines the positional constructor ---
    m: int | None  # identity, copied from the Trace
    spec: np.ndarray  # NaN marks masked pixels
    sig: np.ndarray  # NaN marks masked pixels

    # --- optional fields (dataclass defaults must follow required ones) ---
    group: str | int | None = None
    fiber_idx: int | None = None
    wave: np.ndarray | None = None
    cont: np.ndarray | None = None
    slitfu: np.ndarray | None = None
    extraction_height: float | None = None

    @classmethod
    def from_trace(
        cls, trace: Trace, spec: np.ndarray, sig: np.ndarray, **kwargs
    ) -> Spectrum:
        """Build a spectrum whose identity fields are taken from ``trace``.

        Parameters
        ----------
        trace : Trace
            Object providing ``m``, ``group`` and ``fiber_idx``.
        spec : np.ndarray
            Extracted spectrum.
        sig : np.ndarray
            Uncertainty of the extracted spectrum.
        **kwargs
            Any remaining fields (wave, cont, slitfu, extraction_height).

        Returns
        -------
        Spectrum
            New instance carrying the identity of ``trace``.
        """
        identity = {
            "m": trace.m,
            "group": trace.group,
            "fiber_idx": trace.fiber_idx,
        }
        return cls(spec=spec, sig=sig, **identity, **kwargs)

    def normalized(self) -> tuple[np.ndarray, np.ndarray]:
        """Divide flux and uncertainty by the continuum.

        Returns
        -------
        spec_norm : np.ndarray
            ``spec / cont``.
        sig_norm : np.ndarray
            ``sig / cont``.

        Raises
        ------
        ValueError
            When the spectrum has no continuum.
        """
        continuum = self.cont
        if continuum is None:
            raise ValueError("No continuum available for normalization")
        flux_norm = self.spec / continuum
        err_norm = self.sig / continuum
        return flux_norm, err_norm

    @property
    def mask(self) -> np.ndarray:
        """Boolean array that is True wherever ``spec`` is NaN (masked)."""
        masked = np.isnan(self.spec)
        return masked
@dataclass
class Spectra:
    """Container for multiple spectra from one observation.

    Replaces the legacy Echelle class.

    Attributes
    ----------
    header : fits.Header
        FITS header with observation metadata.
    data : list[Spectrum]
        Individual spectra, one per trace.
    params : ExtractionParams | None
        Global extraction parameters (stored in header).
    """

    header: fits.Header
    data: list[Spectrum]
    params: ExtractionParams | None = None

    @property
    def ntrace(self) -> int:
        """Number of traces/spectra."""
        return len(self.data)

    @property
    def ncol(self) -> int:
        """Number of columns (pixels), taken from the first spectrum.

        Returns 0 when the container is empty.
        """
        if not self.data:
            return 0
        return len(self.data[0].spec)

    def select(
        self,
        m: int | None = None,
        group: str | int | None = None,
        fiber_idx: int | None = None,
    ) -> list[Spectrum]:
        """Filter spectra by order, group, and/or fiber index.

        Criteria left as None are not applied; the remaining ones are
        combined with logical AND.

        Parameters
        ----------
        m : int, optional
            Select spectra with this spectral order number.
        group : str or int, optional
            Select spectra with this group identifier.
        fiber_idx : int, optional
            Select spectra with this fiber index.

        Returns
        -------
        list[Spectrum]
            Matching spectra. With no criteria this is ``self.data`` itself.
        """
        result = self.data
        if m is not None:
            result = [s for s in result if s.m == m]
        if group is not None:
            result = [s for s in result if s.group == group]
        if fiber_idx is not None:
            result = [s for s in result if s.fiber_idx == fiber_idx]
        return result

    def _stack_optional(self, name: str) -> np.ndarray | None:
        """Stack an optional per-spectrum array attribute (e.g. 'wave').

        Returns None when no spectrum provides the attribute (this includes
        an empty container). Otherwise spectra lacking it contribute a row
        of NaN with the length of their ``spec``, mirroring what ``save``
        writes to disk.
        """
        rows = [getattr(s, name) for s in self.data]
        if all(row is None for row in rows):
            return None
        return np.array(
            [
                row if row is not None else np.full(len(s.spec), np.nan)
                for row, s in zip(rows, self.data)
            ]
        )

    def get_arrays(self) -> dict[str, np.ndarray]:
        """Get stacked arrays for all spectra.

        Returns
        -------
        dict
            Dictionary with keys 'spec', 'sig', 'wave', 'cont', 'm', 'group',
            'fiber_idx'. Arrays are stacked along axis 0 (one row per trace).
            'wave' and 'cont' are None when no spectrum carries them; rows of
            spectra that lack them are filled with NaN. Missing 'm' and
            'fiber_idx' are encoded as -1 and a missing 'group' as ''.
        """
        return {
            "spec": np.array([s.spec for s in self.data]),
            "sig": np.array([s.sig for s in self.data]),
            "wave": self._stack_optional("wave"),
            "cont": self._stack_optional("cont"),
            "m": np.array([s.m if s.m is not None else -1 for s in self.data]),
            "group": np.array(
                [str(s.group) if s.group is not None else "" for s in self.data]
            ),
            "fiber_idx": np.array(
                [s.fiber_idx if s.fiber_idx is not None else -1 for s in self.data]
            ),
        }

    @staticmethod
    def read(
        fname: str | Path,
        raw: bool = False,
        continuum_normalization: bool = False,
        barycentric_correction: bool = True,
        radial_velocity_correction: bool = True,
    ) -> Spectra:
        """Read spectra from a FITS file.

        Supports both new format (E_FMTVER >= 2) and legacy format. The
        primary header is taken from HDU 0 and the table from HDU 1.

        Parameters
        ----------
        fname : str or Path
            Input file path.
        raw : bool
            If True, skip all corrections.
        continuum_normalization : bool
            Apply continuum normalization (default False for new format).
        barycentric_correction : bool
            Apply barycentric correction to wavelength.
        radial_velocity_correction : bool
            Apply radial velocity correction to wavelength.

        Returns
        -------
        Spectra
            Loaded spectra.
        """
        with fits.open(fname, memmap=False) as hdu:
            header = hdu[0].header
            table = hdu[1].data
            # A missing keyword means the file predates format versioning.
            fmtver = header.get("E_FMTVER", 1)
            reader = _read_new_format if fmtver >= 2 else _read_legacy_format
            return reader(
                fname,
                header,
                table,
                raw,
                continuum_normalization,
                barycentric_correction,
                radial_velocity_correction,
            )

    def save(self, fname: str | Path, steps: list[str] | None = None) -> None:
        """Save spectra to a FITS file.

        The header goes into the primary HDU and the spectra into a binary
        table extension named 'SPECTRA' with one row per trace. An existing
        file is overwritten.

        Parameters
        ----------
        fname : str or Path
            Output file path.
        steps : list[str], optional
            Pipeline steps that have been run.
        """
        header = self.header.copy() if self.header else fits.Header()

        # Add format metadata
        header["E_FMTVER"] = (FORMAT_VERSION, "PyReduce format version")
        if steps:
            header["E_STEPS"] = (",".join(steps), "Pipeline steps run")

        # Add extraction parameters
        if self.params:
            self.params.to_header(header)

        ntrace = self.ntrace
        ncol = self.ncol

        # Stack arrays; missing integer identities are encoded as -1.
        spec_arr = np.array([s.spec for s in self.data], dtype=np.float32)
        sig_arr = np.array([s.sig for s in self.data], dtype=np.float32)
        m_arr = np.array(
            [s.m if s.m is not None else -1 for s in self.data], dtype=np.int16
        )
        # A missing group is written as an empty string (the reader maps ""
        # back to None); str(None) would store the literal text "None".
        group_arr = np.array(
            ["" if s.group is None else str(s.group) for s in self.data],
            dtype="U16",
        )
        fiber_idx_arr = np.array(
            [s.fiber_idx if s.fiber_idx is not None else -1 for s in self.data],
            dtype=np.int16,
        )
        height_arr = np.array(
            [
                s.extraction_height if s.extraction_height is not None else np.nan
                for s in self.data
            ],
            dtype=np.float32,
        )

        # Build columns
        columns = [
            fits.Column(name="SPEC", format=f"{ncol}E", array=spec_arr),
            fits.Column(name="SIG", format=f"{ncol}E", array=sig_arr),
            fits.Column(name="M", format="I", array=m_arr),
            fits.Column(name="GROUP", format="16A", array=group_arr),
            fits.Column(name="FIBER_IDX", format="I", array=fiber_idx_arr),
            fits.Column(name="EXTR_H", format="E", array=height_arr),
        ]

        # Optional: wavelength (NaN rows for spectra without one)
        if any(s.wave is not None for s in self.data):
            wave_arr = np.array(
                [
                    s.wave if s.wave is not None else np.full(ncol, np.nan)
                    for s in self.data
                ],
                dtype=np.float64,
            )
            columns.append(fits.Column(name="WAVE", format=f"{ncol}D", array=wave_arr))

        # Optional: continuum (NaN rows for spectra without one)
        if any(s.cont is not None for s in self.data):
            cont_arr = np.array(
                [
                    s.cont if s.cont is not None else np.full(ncol, np.nan)
                    for s in self.data
                ],
                dtype=np.float32,
            )
            columns.append(fits.Column(name="CONT", format=f"{ncol}E", array=cont_arr))

        # Optional: slit function. Lengths may differ per trace, so every row
        # is right-padded with NaN up to the longest one.
        if any(s.slitfu is not None for s in self.data):
            max_slitfu_len = max(
                len(s.slitfu) if s.slitfu is not None else 0 for s in self.data
            )
            slitfu_arr = np.full((ntrace, max_slitfu_len), np.nan, dtype=np.float32)
            for i, s in enumerate(self.data):
                if s.slitfu is not None:
                    slitfu_arr[i, : len(s.slitfu)] = s.slitfu
            columns.append(
                fits.Column(
                    name="SLITFU", format=f"{max_slitfu_len}E", array=slitfu_arr
                )
            )

        # Create HDU list
        primary = fits.PrimaryHDU(header=header)
        table = fits.BinTableHDU.from_columns(columns, name="SPECTRA")
        hdulist = fits.HDUList([primary, table])
        hdulist.writeto(fname, overwrite=True, output_verify="silentfix+ignore")
        logger.info("Saved %d spectra to: %s", ntrace, fname)
def _read_new_format(
    fname,
    header,
    table,
    raw,
    continuum_normalization,
    barycentric_correction,
    radial_velocity_correction,
) -> Spectra:
    """Read new format (E_FMTVER >= 2) spectra.

    Parameters
    ----------
    fname : str or Path
        Source file path. Not used here; kept so both format readers share
        one signature.
    header : fits.Header
        Primary header; supplies 'barycorr', 'radvel' and the extraction
        parameters.
    table : FITS table data
        Binary table with SPEC, SIG, M and GROUP (or the older FIBER)
        columns, plus the optional FIBER_IDX, EXTR_H, WAVE, CONT and SLITFU.
    raw : bool
        If True, skip the wavelength correction and continuum normalization.
    continuum_normalization : bool
        Divide SPEC and SIG by CONT when a CONT column exists.
    barycentric_correction : bool
        Subtract the header 'barycorr' value from the velocity shift.
    radial_velocity_correction : bool
        Add the header 'radvel' value to the velocity shift.

    Returns
    -------
    Spectra
        One Spectrum per table row.
    """
    names = table.dtype.names

    def optional(column):
        # Optional columns are absent when no spectrum carried that field.
        return table[column] if column in names else None

    spec_arr = table["SPEC"]
    sig_arr = table["SIG"]
    m_arr = table["M"]
    # Handle both new (GROUP) and old (FIBER) column names
    group_arr = table["GROUP"] if "GROUP" in names else table["FIBER"]
    fiber_idx_arr = optional("FIBER_IDX")
    height_arr = optional("EXTR_H")
    wave_arr = optional("WAVE")
    cont_arr = optional("CONT")
    slitfu_arr = optional("SLITFU")

    # Apply wavelength corrections
    if not raw and wave_arr is not None:
        velocity_correction = 0
        if barycentric_correction:
            velocity_correction -= header.get("barycorr", 0)
        if radial_velocity_correction:
            velocity_correction += header.get("radvel", 0)
        if velocity_correction != 0:
            # m/s -> km/s; NOTE(review): presumably the header velocities
            # are in km/s - confirm against the writer of those keywords.
            speed_of_light = scipy.constants.speed_of_light * 1e-3
            wave_arr = wave_arr * (1 + velocity_correction / speed_of_light)

    # Apply continuum normalization if requested
    if not raw and continuum_normalization and cont_arr is not None:
        spec_arr = spec_arr / cont_arr
        sig_arr = sig_arr / cont_arr

    params = ExtractionParams.from_header(header)

    spectra = []
    for i, (spec, sig) in enumerate(zip(spec_arr, sig_arr)):
        # Negative order numbers encode "unknown".
        m = int(m_arr[i]) if m_arr[i] >= 0 else None

        group = group_arr[i].strip()
        # Empty string or "0" means no group (backward compat). "None" is
        # what Spectra.save() stored for an ungrouped trace when it
        # serialized the group with str(), so it also means "no group".
        if group in ("", "0", "None"):
            group = None
        else:
            # Numeric identifiers (bundle indices) come back as int.
            try:
                group = int(group)
            except ValueError:
                pass

        fiber_idx = (
            int(fiber_idx_arr[i])
            if fiber_idx_arr is not None and fiber_idx_arr[i] >= 0
            else None
        )
        height = (
            float(height_arr[i])
            if height_arr is not None and not np.isnan(height_arr[i])
            else None
        )

        slitfu = slitfu_arr[i] if slitfu_arr is not None else None
        if slitfu is not None:
            # Remove NaN-padding from slitfu; an all-NaN row means this
            # trace had no slit function. Note this drops every NaN entry,
            # not only the trailing padding.
            slitfu = slitfu[~np.isnan(slitfu)]
            if len(slitfu) == 0:
                slitfu = None

        spectra.append(
            Spectrum(
                m=m,
                group=group,
                fiber_idx=fiber_idx,
                spec=spec,
                sig=sig,
                wave=wave_arr[i] if wave_arr is not None else None,
                cont=cont_arr[i] if cont_arr is not None else None,
                slitfu=slitfu,
                extraction_height=height,
            )
        )
    return Spectra(header=header, data=spectra, params=params)
def _read_legacy_format(
    fname,
    header,
    table,
    raw,
    continuum_normalization,
    barycentric_correction,
    radial_velocity_correction,
) -> Spectra:
    """Load spectra stored in the legacy layout (E_FMTVER < 2 or absent).

    The legacy table keeps every quantity in the first row of its column;
    column names are matched case-insensitively. Unless ``raw`` is set,
    wavelength and continuum are expanded to per-pixel arrays and the
    wavelength is shifted by the requested velocity corrections. Pixels
    outside the per-trace COLUMNS range become NaN in both SPEC and SIG
    (this happens for ``raw`` as well).

    Parameters
    ----------
    fname : str or Path
        Source file path; only used in the error message.
    header : fits.Header
        Primary header; supplies 'barycorr' and 'radvel'.
    table : FITS table data
        Legacy binary table.
    raw : bool
        If True, skip expansion, velocity correction and normalization.
    continuum_normalization : bool
        Divide SPEC and SIG by the continuum when one is present.
    barycentric_correction : bool
        Subtract the header 'barycorr' value from the velocity shift.
    radial_velocity_correction : bool
        Add the header 'radvel' value to the velocity shift.

    Returns
    -------
    Spectra
        One Spectrum per trace, with ``m`` set to the row index and
        ``params`` set to None.

    Raises
    ------
    ValueError
        If the table has no SPEC column.
    """
    # First (only) row of each column, keyed by lowercase column name
    first_row = {name.lower(): table[name][0] for name in table.dtype.names}
    flux = first_row.get("spec")
    err = first_row.get("sig")
    wavelength = first_row.get("wave")
    continuum = first_row.get("cont")
    col_range = first_row.get("columns")

    if flux is None:
        raise ValueError(f"No SPEC column found in {fname}")
    ntrace, ncol = flux.shape

    apply_corrections = not raw
    if apply_corrections:
        if wavelength is not None:
            wavelength = _expand_polynomial(ncol, wavelength)
            shift = 0
            if barycentric_correction:
                shift -= header.get("barycorr", 0)
            if radial_velocity_correction:
                shift += header.get("radvel", 0)
            if shift != 0:
                c_kms = scipy.constants.speed_of_light * 1e-3
                wavelength = wavelength * (1 + shift / c_kms)
        if continuum is not None:
            continuum = _expand_polynomial(ncol, continuum)

    # Legacy masking: only pixels inside [start, stop) of COLUMNS are valid
    if col_range is not None:
        keep = np.full((ntrace, ncol), False)
        for i, row in enumerate(keep):
            row[col_range[i, 0] : col_range[i, 1]] = True
        flux = np.where(keep, flux, np.nan)
        err = np.where(keep, err, np.nan)

    if apply_corrections and continuum_normalization and continuum is not None:
        flux = flux / continuum
        err = err / continuum

    # The legacy layout stores no order/group identity, so the row index
    # stands in for the order number.
    spectra = [
        Spectrum(
            m=i,
            spec=flux[i],
            sig=err[i],
            wave=None if wavelength is None else wavelength[i],
            cont=None if continuum is None else continuum[i],
        )
        for i in range(ntrace)
    ]
    return Spectra(header=header, data=spectra, params=None)
def _expand_polynomial(ncol: int, poly: np.ndarray) -> np.ndarray:
    """Turn stored polynomial coefficients into a per-pixel array.

    Three layouts are recognised:

    1. 1D array (REDUCE make_wave format): expanded as a 2D polynomial.
    2. 2D array with fewer than 20 columns: treated as one 1D polynomial
       per order and evaluated over ``ncol`` pixels.
    3. Any other array: assumed to be expanded already and returned as is.

    Parameters
    ----------
    ncol : int
        Number of pixels per trace (used for layout 2 only).
    poly : np.ndarray
        Coefficients or already expanded values.

    Returns
    -------
    np.ndarray
        Expanded array; the input object itself for layout 3.
    """
    if poly.ndim == 1:
        return _calc_2dpolynomial(poly)
    # Heuristic: 20 or more columns cannot be polynomial coefficients.
    if poly.shape[1] >= 20:
        return poly
    return _calc_1dpolynomials(ncol, poly)
def _calc_2dpolynomial(solution2d: np.ndarray) -> np.ndarray:
    """Evaluate a packed 2D polynomial (REDUCE make_wave format).

    Layout of ``solution2d`` as read by this function: element 1 is the
    number of columns, 2 the number of traces, 3 the first order number,
    7/8/9 the cross/column/order degrees, and everything from element 10
    on is the coefficient list (constant, pure column terms, pure order
    terms, then cross terms).

    Returns
    -------
    np.ndarray
        Array of shape (ntrace, ncol): the polynomial evaluated at
        (order / 100, pixel / 1000), divided by the order number.
    """
    ncol, ntrace, order_base = (int(solution2d[k]) for k in (1, 2, 3))
    deg_cross, deg_column, deg_order = (int(solution2d[k]) for k in (7, 8, 9))
    packed = solution2d[10:]

    # coeff[i, j] multiplies order**i * column**j
    coeff = np.zeros((deg_order + 1, deg_column + 1))
    coeff[0, 0] = packed[0]
    split = 1 + deg_column
    coeff[0, 1:] = packed[1:split]
    coeff[1:, 0] = packed[split : split + deg_order]

    # Cross terms follow the pure terms, in this fixed order.
    cross_terms = []
    if deg_cross in (4, 6):
        cross_terms += [(1, 1), (1, 2), (2, 1), (2, 2)]
    if deg_cross == 6:
        cross_terms += [(1, 3), (3, 1)]
    first_cross = deg_column + deg_order + 1
    for offset, (i, j) in enumerate(cross_terms):
        coeff[i, j] = packed[first_cross + offset]

    orders = np.arange(order_base, order_base + ntrace, dtype=float)
    pixels = np.arange(ncol, dtype=float)
    surface = np.polynomial.polynomial.polygrid2d(orders / 100, pixels / 1000, coeff)
    return surface / orders[:, None]
def _calc_1dpolynomials(ncol: int, poly: np.ndarray) -> np.ndarray:
    """Evaluate one 1D polynomial per order over the pixel grid.

    Parameters
    ----------
    ncol : int
        Number of pixels; the polynomials are evaluated at 0 .. ncol - 1.
    poly : np.ndarray
        One row of coefficients per order, highest power first (the
        ``np.polyval`` convention).

    Returns
    -------
    np.ndarray
        Float array of shape (number of rows in ``poly``, ncol).
    """
    pixels = np.arange(ncol)
    expanded = np.zeros((poly.shape[0], ncol))
    for row, coefficients in zip(expanded, poly):
        row[:] = np.polyval(coefficients, pixels)
    return expanded
# Module-level shortcut, retained for backwards compatibility
def read(fname: str | Path, **kwargs) -> Spectra:
    """Load spectra from ``fname``.

    Thin alias for ``Spectra.read``; all keyword arguments are passed
    through unchanged.
    """
    loaded = Spectra.read(fname, **kwargs)
    return loaded