"""
Fluent Pipeline API for PyReduce.
Provides a cleaner interface for building and running reduction pipelines.
Wraps the existing Step classes internally for backward compatibility.
Example usage:
from pyreduce.pipeline import Pipeline
# Simple: auto-discover files for an instrument
result = Pipeline.from_instrument(
instrument="UVES",
target="HD132205",
night="2010-04-01",
channel="middle",
base_dir="/data",
).run()
# Or build manually with explicit files:
result = (
Pipeline("UVES", output_dir, config=settings)
.bias(bias_files)
.flat(flat_files)
.trace()
.extract(science_files)
.run()
)
# For multi-fiber instruments with separate illumination files:
t1 = pipe.trace_raw([even_flat])
t2 = pipe.trace_raw([odd_flat])
pipe.organize(t1, t2)
pipe.extract([science_file]).run()
"""
from __future__ import annotations
import logging
import os
from os.path import join
from typing import TYPE_CHECKING
import numpy as np
from . import util
from .configuration import load_config
from .instruments.instrument_info import load_instrument
from .reduce import (
BackgroundScatter,
Bias,
ContinuumNormalization,
Finalize,
Flat,
LaserFrequencyCombFinalize,
LaserFrequencyCombMaster,
Mask,
NormalizeFlatField,
RectifyImage,
ScienceExtraction,
SlitCurvatureDetermination,
Trace,
WavelengthCalibrationFinalize,
WavelengthCalibrationInitialize,
WavelengthCalibrationMaster,
)
if TYPE_CHECKING:
from .instruments.common import Instrument
logger = logging.getLogger(__name__)
[docs]
class Pipeline:
"""Fluent API for building reduction pipelines.

Step methods queue work and return the pipeline so that calls can be
chained; ``run()`` executes the queued steps.
"""
# Maps each step name to the Step class (imported from .reduce) that
# implements it; _run_step() looks the class up here by name.
STEP_CLASSES = {
"mask": Mask,
"bias": Bias,
"flat": Flat,
"trace": Trace,
"scatter": BackgroundScatter,
"norm_flat": NormalizeFlatField,
"wavecal_master": WavelengthCalibrationMaster,
"wavecal_init": WavelengthCalibrationInitialize,
"wavecal": WavelengthCalibrationFinalize,
"freq_comb_master": LaserFrequencyCombMaster,
"freq_comb": LaserFrequencyCombFinalize,
"curvature": SlitCurvatureDetermination,
"science": ScienceExtraction,
"continuum": ContinuumNormalization,
"finalize": Finalize,
"rectify": RectifyImage,
}
# Sort key used by run(): queued steps execute in ascending order of these
# values, and step names missing from this table sort last (key 999).
# The key order is also what steps="all" expands to in the classmethods.
STEP_ORDER = {
"mask": 5,
"bias": 10,
"flat": 20,
"trace": 30,
"curvature": 40,
"scatter": 45,
"norm_flat": 50,
"wavecal_master": 60,
"wavecal_init": 64,
"wavecal": 67,
"freq_comb_master": 70,
"freq_comb": 72,
"rectify": 75,
"science": 80,
"continuum": 90,
"finalize": 100,
}
def __init__(
    self,
    instrument: Instrument | str,
    output_dir: str,
    target: str = "",
    channel: str = "",
    night: str = "",
    config: dict | None = None,
    trace_range: tuple[int, int] | None = None,
    plot: int = 0,
    plot_dir: str | None = None,
):
    """Set up an empty reduction pipeline.

    Parameters
    ----------
    instrument : Instrument or str
        Instrument instance, or an instrument name that is resolved with
        ``load_instrument``.
    output_dir : str
        Directory for output files. The placeholders ``{instrument}``,
        ``{target}``, ``{night}`` and ``{channel}`` are filled in with
        ``str.format``.
    target : str, optional
        Target name for output file naming
    channel : str, optional
        Instrument channel (e.g., "RED", "BLUE")
    night : str, optional
        Observation night string
    config : dict, optional
        Configuration dict with step-specific settings. A missing or empty
        value is replaced by an empty dict.
    trace_range : tuple, optional
        (first, last+1) orders to process
    plot : int, optional
        Plot level (0=off, 1=basic, 2=detailed). Default 0.
    plot_dir : str, optional
        Directory to save plots as PNG files. If None, plots are shown
        interactively.
    """
    # Accept either a ready instrument object or a name to look up.
    if isinstance(instrument, str):
        resolved = load_instrument(instrument)
    else:
        resolved = instrument
    self.instrument = resolved

    # Fill the optional placeholders of the output directory template.
    placeholders = {
        "instrument": resolved.name.upper(),
        "target": target,
        "night": night,
        "channel": channel,
    }
    self.output_dir = output_dir.format(**placeholders)

    self.target, self.channel, self.night = target, channel, night
    self.config = config if config else {}
    self.trace_range = trace_range
    self.plot, self.plot_dir = plot, plot_dir

    # Set global plot directory for util.show_or_save()
    util.set_plot_dir(plot_dir)

    # Queue of (step name, files), results per step, and input files per step.
    self._steps: list[tuple[str, list | None]] = []
    self._data: dict = {}
    self._files: dict = {}
def _add_step(self, name: str, files: list | None = None) -> Pipeline:
    """Queue step ``name`` and remember its input files when some are given.

    Returns the pipeline itself so that step methods can be chained.
    """
    entry = (name, files)
    self._steps.append(entry)
    if files is None:
        return self
    # Input files are also indexed by step name for dependency resolution.
    self._files[name] = files
    return self
# Step methods - fluent API
[docs]
def mask(self) -> Pipeline:
    """Queue the bad pixel mask step (load or create the mask)."""
    step_name = "mask"
    return self._add_step(step_name)
[docs]
def bias(self, files: list[str]) -> Pipeline:
    """Queue the master bias step, combining the given bias frames."""
    queued = self._add_step("bias", files)
    return queued
[docs]
def flat(self, files: list[str]) -> Pipeline:
    """Queue the master flat step, combining the given flat frames."""
    queued = self._add_step("flat", files)
    return queued
[docs]
def trace(self, files: list[str] | None = None) -> Pipeline:
    """Queue the fiber/order tracing step.

    Parameters
    ----------
    files : list[str], optional
        Files to use for tracing. When omitted, no files are attached to the
        step and the step works from its dependencies instead.

    Returns
    -------
    Pipeline
        Self for method chaining
    """
    step_name = "trace"
    return self._add_step(step_name, files)
[docs]
def trace_raw(self, files: list[str], order_centers: dict = None) -> list:
    """Trace fibers/orders and return the result without storing it.

    Intended for multi-file tracing workflows where traces from several
    files are combined before grouping. Nothing is written to the
    pipeline's results; the traces are only returned.

    Parameters
    ----------
    files : list[str]
        Files to use for tracing.
    order_centers : dict[int, float], optional
        Order number -> y-position mapping for m assignment. When omitted,
        it is obtained from the Trace step's ``_load_order_centers()``.

    Returns
    -------
    list[Trace]
        Whatever the tracing routine returns (individual fibers, not grouped).
    """
    from .trace import trace as trace_func

    # A Trace step is built only to reuse its calibration and parameters.
    trace_settings = self.config.get("trace", {}).copy()
    trace_settings["plot"] = self.plot
    tracer = Trace(*self._get_step_inputs(), **trace_settings)

    # Mask and bias are used when already available, otherwise None is passed.
    available = self._data
    order_img, _ = tracer.calibrate(
        files, available.get("mask"), available.get("bias"), None
    )

    # Fall back to the order centers known to the step.
    if order_centers is None:
        order_centers = tracer._load_order_centers()

    tracing_options = {
        "min_cluster": tracer.min_cluster,
        "min_width": tracer.min_width,
        "filter_x": tracer.filter_x,
        "filter_y": tracer.filter_y,
        "filter_type": tracer.filter_type,
        "noise": tracer.noise,
        "noise_relative": tracer.noise_relative,
        "degree": tracer.fit_degree,
        "degree_before_merge": tracer.degree_before_merge,
        "regularization": tracer.regularization,
        "closing_shape": tracer.closing_shape,
        "opening_shape": tracer.opening_shape,
        "border_width": tracer.border_width,
        "manual": tracer.manual,
        "auto_merge_threshold": tracer.auto_merge_threshold,
        "merge_min_threshold": tracer.merge_min_threshold,
        "sigma": tracer.sigma,
        "plot": self.plot,
        "plot_title": tracer.plot_title,
        "order_centers": order_centers,
    }
    return trace_func(order_img, **tracing_options)
[docs]
def organize(self, traces: list, *more) -> Pipeline:
    """Combine raw traces, group them into fibers if configured, and save.

    Use this after trace_raw(). All given trace lists are concatenated and
    sorted by their y-position at the middle of their column range. The
    result is stored as the pipeline's "trace" data and saved through a
    Trace step.

    Parameters
    ----------
    traces : list[Trace]
        Trace objects from trace_raw()
    *more : additional list[Trace]
        Optional additional trace lists to concatenate

    Returns
    -------
    Pipeline
        Self for method chaining

    Example
    -------
    >>> t1 = pipe.trace_raw([even_flat])
    >>> t2 = pipe.trace_raw([odd_flat])
    >>> pipe.organize(t1, t2)
    """
    from .trace import group_fibers

    # Merge every supplied list into one flat list of traces.
    combined = list(traces)
    for extra in more:
        combined.extend(extra)

    def y_at_middle(tr):
        # y-position of the trace at the centre of its column range
        return tr.y_at_x(sum(tr.column_range) / 2)

    combined.sort(key=y_at_middle)

    fibers_config = getattr(self.instrument.config, "fibers", None)
    trace_settings = self.config.get("trace", {}).copy()
    degree = trace_settings.get("degree", 4)

    # Grouping applies only when the instrument defines groups or bundles.
    needs_grouping = fibers_config is not None and not (
        fibers_config.groups is None and fibers_config.bundles is None
    )
    if needs_grouping:
        logger.info("Grouping %d traces into fiber groups", len(combined))
        trace_objects = group_fibers(combined, fibers_config, degree=degree)
    else:
        trace_objects = combined

    self._data["trace"] = trace_objects

    # Persist the result through a Trace step so it can be loaded later.
    trace_settings["plot"] = self.plot
    saver = Trace(*self._get_step_inputs(), **trace_settings)
    saver.trace_objects = trace_objects
    heights = [np.nan if t.height is None else t.height for t in trace_objects]
    saver.heights = np.array(heights)
    saver.save()
    return self
[docs]
def curvature(self, files: list[str] | None = None) -> Pipeline:
    """Queue the slit curvature determination step (p1/p2)."""
    step_name = "curvature"
    return self._add_step(step_name, files)
[docs]
def scatter(self, files: list[str] | None = None) -> Pipeline:
    """Queue the background scatter model fit."""
    queued = self._add_step("scatter", files)
    return queued
[docs]
def normalize_flat(self) -> Pipeline:
    """Queue flat field normalization (blaze function extraction)."""
    # The step is registered under the short name "norm_flat".
    step_name = "norm_flat"
    return self._add_step(step_name)
[docs]
def wavecal_master(self, files: list[str]) -> Pipeline:
    """Queue extraction of the wavelength calibration spectrum."""
    queued = self._add_step("wavecal_master", files)
    return queued
[docs]
def wavecal_init(self) -> Pipeline:
    """Queue initialization of the wavelength solution from the line atlas."""
    step_name = "wavecal_init"
    return self._add_step(step_name)
[docs]
def wavecal(self) -> Pipeline:
    """Queue the final wavelength calibration step."""
    queued = self._add_step("wavecal")
    return queued
[docs]
def wavelength_calibration(self, files: list[str]) -> Pipeline:
    """Queue the full wavelength calibration (master + init + finalize)."""
    # Each call returns the pipeline, so the three steps are queued in order.
    pipeline = self.wavecal_master(files)
    pipeline = pipeline.wavecal_init()
    return pipeline.wavecal()
[docs]
def freq_comb_master(self, files: list[str]) -> Pipeline:
    """Queue extraction of the laser frequency comb spectrum."""
    step_name = "freq_comb_master"
    return self._add_step(step_name, files)
[docs]
def freq_comb(self) -> Pipeline:
    """Queue the final frequency comb calibration step."""
    queued = self._add_step("freq_comb")
    return queued
[docs]
def extract(self, files: list[str]) -> Pipeline:
    """Queue science extraction for the given science frames.

    The usage examples in this module and ``from_files`` call
    ``pipe.extract(...)``; the step is registered under the name "science".

    Parameters
    ----------
    files : list[str]
        Science frames to extract.

    Returns
    -------
    Pipeline
        Self for method chaining
    """
    return self._add_step("science", files)


def continuum(self) -> Pipeline:
    """Queue continuum normalization.

    Returns
    -------
    Pipeline
        Self for method chaining
    """
    return self._add_step("continuum")
[docs]
def finalize(self) -> Pipeline:
    """Queue the step that writes the final output files."""
    step_name = "finalize"
    return self._add_step(step_name)
[docs]
def rectify(self) -> Pipeline:
    """Queue rectification of the 2D image."""
    queued = self._add_step("rectify")
    return queued
# Loading intermediate results
[docs]
def load(self, step: str, data=None) -> Pipeline:
    """Register an intermediate result for ``step`` instead of computing it.

    Parameters
    ----------
    step : str
        Name of step whose output to register
    data : any, optional
        Data to use directly. When omitted, None is stored as a marker.

    Returns
    -------
    Pipeline
        Self for method chaining
    """
    # Given data is stored as-is; without data the entry is the None marker.
    # NOTE(review): run() re-runs a queued step whose entry is None, while
    # _run_step() passes an existing None entry on to dependants unchanged
    # rather than loading it from disk - confirm this is intended.
    self._data[step] = data
    return self
# Execution
def _get_step_inputs(self) -> tuple:
    """Return the positional arguments shared by all Step constructors.

    The order is instrument, channel, target, night, output_dir, trace_range.
    """
    attribute_names = (
        "instrument",
        "channel",
        "target",
        "night",
        "output_dir",
        "trace_range",
    )
    return tuple(getattr(self, attr) for attr in attribute_names)
def _run_step(self, name: str, files: list | None, load_only: bool = False):
    """Run a single step, or load its saved output when ``load_only`` is set.

    Parameters
    ----------
    name : str
        Step name, a key of ``STEP_CLASSES``.
    files : list or None
        Input files for the step; passed to ``run`` as ``files`` when given.
    load_only : bool
        If True, try ``step.load`` first. When nothing has been saved, the
        step is run instead, provided input files are available.

    Raises
    ------
    FileNotFoundError
        If loading fails and no input files were provided to run the step.
    """
    step_class = self.STEP_CLASSES[name]
    settings = self.config.get(name, {}).copy()
    settings["plot"] = self.plot  # Runtime plot setting
    step = step_class(*self._get_step_inputs(), **settings)
    step.files = self._files  # Make input files available to step

    # Loading and running may need different dependencies.
    required = step.loadDependsOn if load_only else step.dependsOn
    for dep in required:
        if dep not in self._data:
            self._ensure_dependency(dep)
    dep_args = {dep: self._data[dep] for dep in required}

    if not load_only:
        logger.info("Running step '%s'", name)
        if files is not None:
            dep_args["files"] = files
        return step.run(**dep_args)

    try:
        logger.info("Loading data from step '%s'", name)
        return step.load(**dep_args)
    except FileNotFoundError:
        if files is None:
            raise FileNotFoundError(
                f"No saved data for step '{name}' and no input files provided to run it."
            ) from None
        logger.warning(
            "Intermediate files for step '%s' not found, running instead.",
            name,
        )
        # Retry as a real run, which resolves the run-time dependencies.
        return self._run_step(name, files, load_only=False)
def _ensure_dependency(self, name: str):
    """Make sure ``self._data`` has an entry for dependency ``name``.

    Existing entries are left untouched. The special name "config" maps to
    the full configuration dict; any other name is obtained from the step of
    that name in load-only mode.
    """
    if name not in self._data:
        if name == "config":
            # 'config' is the full config dict, not a step
            self._data["config"] = self.config
        else:
            step_files = self._files.get(name)
            self._data[name] = self._run_step(name, step_files, load_only=True)
[docs]
def run(self, skip_existing: bool = False) -> dict:
    """Execute all queued steps in ``STEP_ORDER`` order.

    Steps that already have a non-None entry in the results are not run
    again.

    Parameters
    ----------
    skip_existing : bool
        Reserved for skipping steps whose output files already exist.
        It is currently not acted upon; a warning is logged when it is set
        so that the request is not dropped silently.

    Returns
    -------
    dict
        The pipeline's data keyed by name. Besides the queued steps this
        includes any dependencies that were resolved while running.
    """
    if skip_existing:
        logger.warning(
            "skip_existing=True is not implemented; every queued step "
            "without a result will be run."
        )

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(self.output_dir, exist_ok=True)

    # Unknown step names sort last.
    ordered = sorted(self._steps, key=lambda item: self.STEP_ORDER.get(item[0], 999))
    for name, files in ordered:
        # A missing entry or a None entry both mean "not yet computed".
        if self._data.get(name) is not None:
            continue
        self._data[name] = self._run_step(name, files)

    util.show_all()
    return self._data
@property
def results(self) -> dict:
    """Return the pipeline's data dict; it is filled in by ``run()``."""
    data = self._data
    return data
[docs]
@classmethod
def from_files(
    cls,
    files: dict,
    output_dir: str,
    target: str,
    instrument,
    channel: str,
    night: str,
    config: dict,
    trace_range=None,
    steps="all",
    plot: int = 0,
    plot_dir: str | None = None,
) -> Pipeline:
    """Create a pipeline from a files dict and queue the requested steps.

    Nothing is executed here; call ``run()`` on the returned pipeline.

    Parameters
    ----------
    files : dict
        Files for each step, keyed by step name. The keys read here are
        bias, flat, trace, curvature, scatter, wavecal_master,
        freq_comb_master and science. trace, curvature and scatter fall
        back to the flat files when their own key is missing.
    output_dir : str
        Output directory
    target : str
        Target name
    instrument : Instrument or str
        Instrument instance or name
    channel : str
        Instrument channel
    night : str
        Observation night
    config : dict
        Configuration dict
    trace_range : tuple, optional
        Order range to process
    steps : list or "all"
        Steps to queue; "all" expands to every key of ``STEP_ORDER``.
        Names without a handler here (for example "mask") are ignored.
    plot : int, optional
        Plot level (0=off, 1=basic, 2=detailed). Default 0.
    plot_dir : str, optional
        Directory to save plots as PNG files. If None, plots are shown
        interactively.

    Returns
    -------
    Pipeline
        Configured pipeline ready to run
    """
    pipe = cls(
        instrument=instrument,
        output_dir=output_dir,
        target=target,
        channel=channel,
        night=night,
        config=config,
        trace_range=trace_range,
        plot=plot,
        plot_dir=plot_dir,
    )

    if isinstance(steps, str) and steps == "all":
        steps = list(cls.STEP_ORDER.keys())

    def has_files(key):
        # len() rather than truthiness, since files can be numpy arrays
        return len(files.get(key, [])) > 0

    # Register files for steps that may be needed as dependencies
    # (even if the step itself isn't in the steps list)
    for key in (
        "bias",
        "flat",
        "trace",
        "curvature",
        "scatter",
        "wavecal_master",
        "freq_comb_master",
        "science",
    ):
        if has_files(key):
            pipe._files[key] = files[key]

    def queue_if_files(key, add):
        # Steps that need input files are skipped when none are available.
        if has_files(key):
            add(files[key])

    def files_or_flat(key):
        return files.get(key, files.get("flat"))

    def queue_science(science_files):
        # The class defines no extract() wrapper, so the "science" step is
        # queued through _add_step like every other file-taking step.
        pipe._add_step("science", science_files)

    # Handlers are lambdas so that pipeline methods are only looked up for
    # the steps that were actually requested.
    step_map = {
        "bias": lambda: queue_if_files("bias", pipe.bias),
        "flat": lambda: queue_if_files("flat", pipe.flat),
        "trace": lambda: pipe.trace(files_or_flat("trace")),
        "curvature": lambda: pipe.curvature(files_or_flat("curvature")),
        "scatter": lambda: pipe.scatter(files_or_flat("scatter")),
        "norm_flat": lambda: pipe.normalize_flat(),
        "wavecal_master": lambda: queue_if_files(
            "wavecal_master", pipe.wavecal_master
        ),
        "wavecal_init": lambda: pipe.wavecal_init(),
        "wavecal": lambda: pipe.wavecal(),
        "freq_comb_master": lambda: queue_if_files(
            "freq_comb_master", pipe.freq_comb_master
        ),
        "freq_comb": lambda: pipe.freq_comb(),
        "rectify": lambda: pipe.rectify(),
        "science": lambda: queue_if_files("science", queue_science),
        "continuum": lambda: pipe.continuum(),
        "finalize": lambda: pipe.finalize(),
    }

    for step in steps:
        queue = step_map.get(step)
        if queue is not None:
            queue()
    return pipe
[docs]
@classmethod
def from_instrument(
cls,
instrument: str,
target: str,
night: str | None = None,
channel: str | None = None,
steps: tuple | list | str = "all",
base_dir: str | None = None,
input_dir: str | None = None,
output_dir: str | None = None,
configuration: dict | None = None,
trace_range: tuple[int, int] | None = None,
plot: int = 0,
plot_dir: str | None = None,
) -> Pipeline:
"""Create pipeline from instrument name with automatic file discovery.
This is the recommended entry point for running reductions. It handles
loading the instrument, finding and sorting files, and setting up
the pipeline with the correct configuration.
Parameters
----------
instrument : str
Instrument name (e.g., "UVES", "HARPS", "XSHOOTER")
target : str
Target name or regex pattern to match in headers
night : str, optional
Observation night (YYYY-MM-DD format or regex)
channel : str, optional
Instrument channel (e.g., "RED", "BLUE", "middle"). If None,
uses all available channels for the instrument.
steps : tuple, list, or "all"
Steps to run. Default "all" runs all applicable steps.
base_dir : str, optional
Base directory for data. Default: $REDUCE_DATA or ~/REDUCE_DATA
input_dir : str, optional
Input directory relative to base_dir. Default: from config
output_dir : str, optional
Output directory relative to base_dir. Default: from config
configuration : dict, optional
Configuration overrides. Default: instrument defaults
trace_range : tuple, optional
(first, last+1) orders to process
plot : int
Plot level (0=off, 1=basic, 2=detailed)
plot_dir : str, optional
Directory to save plots. If None, shows interactively.
Returns
-------
Pipeline
Configured pipeline ready to call .run()
Example
-------
>>> result = Pipeline.from_instrument(
... instrument="UVES",
... target="HD132205",
... night="2010-04-01",
... channel="middle",
... steps=("bias", "flat", "trace", "science"),
... ).run()
"""
# Environment variable overrides for plot
if "PYREDUCE_PLOT" in os.environ:
plot = int(os.environ["PYREDUCE_PLOT"])
if "PYREDUCE_PLOT_DIR" in os.environ:
plot_dir = os.environ["PYREDUCE_PLOT_DIR"]
plot_show = os.environ.get("PYREDUCE_PLOT_SHOW", "block")
# Set global plot settings
util.set_plot_dir(plot_dir)
util.set_plot_show(plot_show, plot_level=plot)
# Load instrument (before config, so we can get settings fallbacks)
inst = load_instrument(instrument)
# Load configuration (channel-specific if settings_{channel}.json exists)
channel_fallbacks = inst.get_settings_fallbacks(channel) if channel else None
config = load_config(
configuration,
instrument,
0,
channel=channel,
channel_fallbacks=channel_fallbacks,
)
info = inst.info
# Get directories from config if not specified
if base_dir is None:
base_dir = config["reduce"]["base_dir"]
if input_dir is None:
input_dir = config["reduce"]["input_dir"]
if output_dir is None:
output_dir = config["reduce"]["output_dir"]
full_input_dir = join(base_dir, input_dir)
full_output_dir = join(base_dir, output_dir)
# Get channels to process
if channel is None:
channels = info["channels"]
else:
channels = [channel] if isinstance(channel, str) else channel
# Normalize steps for file sorting warnings
steps_list = list(cls.STEP_ORDER.keys()) if steps == "all" else list(steps)
# Find and sort files
files = inst.sort_files(
full_input_dir,
target,
night,
channel=channels[0] if len(channels) == 1 else channels[0],
steps=steps_list,
**config["instrument"],
)
if len(files) == 0:
logger.warning(
"No files found for instrument: %s, target: %s, night: %s, channel: %s",
instrument,
target,
night,
channel,
)
raise FileNotFoundError(
f"No files found for {instrument} / {target} / {night} / {channel}"
)
# Use the first file set (for single channel)
k, f = files[0]
logger.info("Pipeline settings:")
for key, value in k.items():
logger.info(" %s: %s", key, value)
# Create pipeline
pipe = cls.from_files(
files=f,
output_dir=full_output_dir,
target=k.get("target", target),
instrument=inst,
channel=channels[0],
night=k.get("night", night or ""),
config=config,
trace_range=trace_range,
steps=steps,
plot=plot,
plot_dir=plot_dir,
)
return pipe