"""
Abstract parent module for all other instruments
Contains some general functionality, which may be overridden by the children of course
"""
import datetime
import glob
import json
import logging
import os.path
from itertools import product
import numpy as np
import yaml
from astropy.io import fits
from astropy.time import Time
from dateutil import parser
from tqdm import tqdm
from ..clipnflip import clipnflip
from .filters import ChannelFilter, Filter, InstrumentFilter, NightFilter, ObjectFilter
from .models import InstrumentConfig
logger = logging.getLogger(__name__)
def parse_iraf_section(section_str):
    """Parse IRAF-style section string into pixel coordinates.

    IRAF format: [x1:x2,y1:y2] where coordinates are 1-based inclusive.
    Whitespace around the whole string or around the individual numbers
    is ignored.

    Parameters
    ----------
    section_str : str
        Section string like "[28:1179,1:4616]"

    Returns
    -------
    x1, x2, y1, y2 : int
        1-based inclusive pixel coordinates

    Raises
    ------
    ValueError
        If the string does not consist of exactly two "a:b" integer ranges
        separated by a comma.
    """
    # Remove surrounding whitespace first: str.strip("[]") only removes
    # brackets that sit at the very ends of the string, so a padded value
    # such as " [1:2,3:4] " would otherwise keep its brackets.
    s = section_str.strip().strip("[]")
    x_part, y_part = s.split(",")
    # int() tolerates whitespace around the digits
    x1, x2 = map(int, x_part.split(":"))
    y1, y2 = map(int, y_part.split(":"))
    return x1, x2, y1, y2
def find_first_index(arr, value):
    """Return the position of the first item in ``arr`` that equals ``value``.

    Parameters
    ----------
    arr : iterable
        items to search, in order
    value : obj
        value to compare each item against (using ``==``)

    Returns
    -------
    index : int
        zero-based position of the first matching item

    Raises
    ------
    KeyError
        if no item compares equal to ``value``
    """
    for position, item in enumerate(arr):
        if item == value:
            return position
    raise KeyError(f"Value {value} not found")
def observation_date_to_night(observation_date):
    """Map an observation timestamp to the date of its observing night.

    A night runs from noon to noon: a timestamp whose hour is before 12
    is assigned to the previous calendar date, any other timestamp to its
    own calendar date.

    Parameters
    ----------
    observation_date : str
        timestamp of the observation, parsed with ``dateutil.parser.parse``.
        An empty string means "no date".

    Returns
    -------
    night : datetime.date or None
        night of the observation, or None for an empty string
    """
    if observation_date == "":
        return None

    timestamp = parser.parse(observation_date)
    if timestamp.hour >= 12:
        return timestamp.date()
    # Morning hours still belong to the night that started the day before
    return (timestamp - datetime.timedelta(days=1)).date()
class getter:
    """Look up header values through the instrument info of one channel.

    The instrument info maps REDUCE property names to header keywords (or
    static values). Entries that are lists hold one value per channel; on
    construction they are reduced to the value of the requested channel.
    """

    def __init__(self, header, info, channel):
        self.header = header
        self.info = info.copy()

        try:
            self.index = find_first_index(info["channels"], channel.upper())
        except KeyError:
            # Either there is no "channels" entry or the channel is not listed
            logger.warning("No instrument channels found in instrument info")
            self.index = 0

        # Replace every per-channel list by the entry of the selected channel
        per_channel = {
            name: entry[self.index]
            for name, entry in self.info.items()
            if isinstance(entry, list)
        }
        self.info.update(per_channel)

    def __call__(self, key, alt=None):
        return self.get(key, alt=alt)

    def get(self, key, alt=None):
        """Get data

        Parameters
        ----------
        key : str
            REDUCE name of the property; if it is not part of the instrument
            info it is used as the header keyword directly
        alt : obj, optional
            alternative value, if the keyword does not exist in the header
            (default: None)

        Returns
        -------
        value : obj
            the static value from the instrument info if it is not a string,
            otherwise the value found in the header (or alternatively alt)
        """
        keyword = self.info.get(key, key)
        if not isinstance(keyword, str):
            # Non-string entries are static values, not header keywords
            return keyword
        # Keywords may contain {name} placeholders referring to other entries
        return self.header.get(keyword.format(**self.info), alt)
class Instrument:
"""
Abstract parent class for all instruments
Handles the instrument specific information
"""
def __init__(self):
#:str: Name of the instrument (lowercase)
self.name = self.__class__.__name__.lower()
#:InstrumentConfig: Validated configuration model
#:dict: Information about the instrument (for backward compatibility)
self.config, self.info = self.load_info()
self.filters = {
"instrument": InstrumentFilter(self.config.instrument, regex=True),
"night": NightFilter(self.config.date, timeformat=self.config.date_format),
"target": ObjectFilter(self.config.target, regex=True),
"bias": Filter(self.config.kw_bias),
"flat": Filter(self.config.kw_flat),
"trace": Filter(self.config.kw_orders),
"curvature": Filter(self.config.kw_curvature),
"scatter": Filter(self.config.kw_scatter),
"wave": Filter(self.config.kw_wave),
"comb": Filter(self.config.kw_comb),
"spec": Filter(self.config.kw_spec),
}
self.night = "night"
self.science = "science"
self.shared = ["instrument", "night"]
# Directory containing instrument config files
self._inst_dir = os.path.join(os.path.dirname(__file__), self.name.upper())
# Add channel filter if kw_channel is defined (for instruments with separate files per channel)
if self.config.kw_channel is not None:
self.filters["channel"] = ChannelFilter(self.config.kw_channel)
self.shared.append("channel")
self.find_closest = [
"bias",
"flat",
"wavecal_master",
"freq_comb_master",
"trace",
"scatter",
"curvature",
]
def __str__(self):
return self.name
@property
def channels(self) -> list[str] | None:
"""Available instrument channels."""
return self.config.channels
@property
def extension(self) -> int | str | list:
"""FITS extension(s) to read."""
return self.config.extension
@property
def orientation(self) -> int | list[int]:
"""Detector orientation code(s)."""
return self.config.orientation
@property
def id_instrument(self) -> str:
"""Instrument identifier for header matching."""
return self.config.id_instrument
[docs]
def get(self, key, header, channel, alt=None):
get = getter(header, self.info, channel)
return get(key, alt=alt)
[docs]
def get_extension(self, header, channel):
channel = channel.upper()
ext = self.extension # Use property
if isinstance(ext, list):
ichannel = find_first_index(self.channels, channel)
ext = ext[ichannel]
return ext
[docs]
def load_info(self):
"""
Load static instrument information
Either as fits header keywords or static values
Returns
------
config : InstrumentConfig
Validated Pydantic model
info : dict(str:object)
dictionary of REDUCE names for properties to Header keywords/static values
"""
# Tips & Tricks:
# if several channels are supported, use a list for channels
# if a value changes depending on the channel, use a list with the same order as "channels"
# you can also use values from this dictionary as placeholders using {name}, just like str.format
this = os.path.dirname(__file__)
inst_dir = os.path.join(this, self.name.upper())
yaml_fname = os.path.join(inst_dir, "config.yaml")
json_fname = os.path.join(inst_dir, "config.json")
if os.path.exists(yaml_fname):
with open(yaml_fname) as f:
info = yaml.safe_load(f)
elif os.path.exists(json_fname):
with open(json_fname) as f:
info = json.load(f)
else:
raise FileNotFoundError(
f"No instrument config found for {self.name} "
f"(tried {yaml_fname} and {json_fname})"
)
# Validate with Pydantic (strict - invalid config is a bug)
config = InstrumentConfig(**info)
return config, info
[docs]
def get_amplifier_extensions(self, header):
"""Get list of amplifier extensions if multi-amp mode is configured.
Parameters
----------
header : fits.Header
Primary FITS header
Returns
-------
list or None
List of extension names/indices if multi-amp, None otherwise
"""
if self.config.amplifiers is None:
return None
amp_config = self.config.amplifiers
# Get number of amplifiers
if isinstance(amp_config.count, int):
n_amps = amp_config.count
else:
n_amps = header.get(amp_config.count)
if n_amps is None:
logger.warning(
"Amplifier count keyword '%s' not found in header",
amp_config.count,
)
return None
# Generate extension names from template
extensions = []
for n in range(1, n_amps + 1):
ext_name = amp_config.extension_template.format(n=n)
extensions.append(ext_name)
return extensions
[docs]
def assemble_amplifiers(self, hdu, amp_extensions, channel):
"""Assemble multi-amplifier readout into single frame.
Reads data from multiple FITS extensions, each representing one
amplifier's readout region, and assembles them into a single frame
using DATASEC/DETSEC mappings.
Parameters
----------
hdu : HDUList
Open FITS file
amp_extensions : list
List of extension names to read
channel : str
Instrument channel
Returns
-------
data : ndarray
Assembled frame (float32)
header : fits.Header
Combined header with per-amplifier e_gain{n}, e_readn{n}
"""
amp_config = self.config.amplifiers
h_prime = hdu[0].header
# First pass: determine output size and collect amp info
xmax, ymax = 0, 0
amp_info = []
for ext in amp_extensions:
h = hdu[ext].header
datasec = parse_iraf_section(h[amp_config.datasec])
detsec = parse_iraf_section(h[amp_config.detsec])
# DETSEC gives the destination in the assembled image
xmax = max(xmax, detsec[1])
ymax = max(ymax, detsec[3])
amp_info.append(
{
"ext": ext,
"datasec": datasec,
"detsec": detsec,
"gain": h.get(amp_config.gain, 1.0),
"readnoise": h.get(amp_config.readnoise, 0.0),
}
)
# Allocate output array
assembled = np.zeros((ymax, xmax), dtype=np.float32)
# Second pass: place each amplifier's data
for info in amp_info:
# Extract valid data region (DATASEC is 1-based inclusive)
dx1, dx2, dy1, dy2 = info["datasec"]
raw = hdu[info["ext"]].data[dy1 - 1 : dy2, dx1 - 1 : dx2]
# Place into output at DETSEC location (1-based inclusive)
ox1, ox2, oy1, oy2 = info["detsec"]
assembled[oy1 - 1 : oy2, ox1 - 1 : ox2] = raw
# Build combined header
header = hdu[amp_extensions[0]].header.copy()
header.extend(h_prime, strip=False)
# Add standard header info first (sets e_orient, e_transpose, etc.)
header = self.add_header_info(header, channel)
header["e_input"] = (
os.path.basename(hdu.filename()),
"Original input filename",
)
# Store per-amplifier calibration values
# Use e_namps (not e_ampl) to avoid triggering clipnflip's multi-amp path
# since we've already assembled the frame
header["e_namps"] = (len(amp_info), "Number of amplifiers in raw data")
for i, info in enumerate(amp_info, 1):
header[f"HIERARCH e_gain{i}"] = (
info["gain"],
f"Gain for amplifier {i}",
)
header[f"HIERARCH e_readn{i}"] = (
info["readnoise"],
f"Readnoise for amplifier {i}",
)
# Override e_gain/e_readn with median across amplifiers
gains = [info["gain"] for info in amp_info]
readnoises = [info["readnoise"] for info in amp_info]
header["e_gain"] = (float(np.median(gains)), "Median gain across amplifiers")
header["e_readn"] = (
float(np.median(readnoises)),
"Median readnoise across amplifiers",
)
# Override bounds for assembled frame (full frame, no prescan/overscan)
header["e_xlo"] = 0
header["e_xhi"] = xmax
header["e_ylo"] = 0
header["e_yhi"] = ymax
return assembled, header
[docs]
def load_fits(
self, fname, channel, extension=None, mask=None, header_only=False, dtype=None
):
"""
load fits file, REDUCE style
primary and extension header are combined
channel-specific info is applied to header
data is clipnflipped
mask is applied
For multi-amplifier instruments, data from multiple extensions
is assembled into a single frame before clipnflip.
Parameters
----------
fname : str
filename
instrument : str
name of the instrument
channel : str
instrument channel
extension : int
data extension of the FITS file to load
mask : array, optional
mask to add to the data
header_only : bool, optional
only load the header, not the data
dtype : str, optional
numpy datatype to convert the read data to
Returns
--------
data : masked_array
FITS data, clipped and flipped, and with mask
header : fits.header
FITS header (Primary and Extension + channel info)
ONLY the header is returned if header_only is True
"""
channel = channel.upper()
hdu = fits.open(fname)
h_prime = hdu[0].header
# Check for multi-amplifier mode
amp_extensions = self.get_amplifier_extensions(h_prime)
if amp_extensions is not None:
# Multi-amplifier path: assemble from multiple extensions
if header_only:
# For header_only, just return first extension header + primary
header = hdu[amp_extensions[0]].header.copy()
header.extend(h_prime, strip=False)
header = self.add_header_info(header, channel)
header["e_input"] = (os.path.basename(fname), "Original input filename")
hdu.close()
return header
data, header = self.assemble_amplifiers(hdu, amp_extensions, channel)
try:
data = clipnflip(data, header)
except IndexError as e:
hdu.close()
raise ValueError(
f"Failed to load {fname} for channel '{channel}': {e}\n"
f"This usually means the file does not contain data for this channel."
) from None
else:
# Single extension path (original behavior)
if extension is None:
extension = self.get_extension(h_prime, channel)
header = hdu[extension].header
if extension != 0:
header.extend(h_prime, strip=False)
header = self.add_header_info(header, channel)
header["e_input"] = (os.path.basename(fname), "Original input filename")
if header_only:
hdu.close()
return header
try:
data = clipnflip(hdu[extension].data, header)
except IndexError as e:
hdu.close()
raise ValueError(
f"Failed to load {fname} for channel '{channel}' (extension {extension}): {e}\n"
f"This usually means the file does not contain data for this channel."
) from None
if dtype is not None:
data = data.astype(dtype)
data = np.ma.masked_array(data, mask=mask)
hdu.close()
return data, header
[docs]
def find_files(self, input_dir):
"""Find fits files in the given folder
Parameters
----------
input_dir : string
directory to look for fits and fits.gz files in, may include bash style wildcards
Returns
-------
files: array(string)
absolute path filenames
"""
files = glob.glob(input_dir + "/*.fits")
files += glob.glob(input_dir + "/*.fits.gz")
files = np.array(files)
return files
[docs]
def get_expected_values(self, target, night, channel=None, **kwargs):
expectations = {
"bias": {
"instrument": self.config.id_instrument,
"night": night,
"bias": self.config.id_bias,
},
"flat": {
"instrument": self.config.id_instrument,
"night": night,
"flat": self.config.id_flat,
},
"trace": {
"instrument": self.config.id_instrument,
"night": night,
"trace": self.config.id_orders,
},
"scatter": {
"instrument": self.config.id_instrument,
"night": night,
"scatter": self.config.id_scatter,
},
"curvature": {
"instrument": self.config.id_instrument,
"night": night,
"curvature": self.config.id_curvature,
},
"wavecal_master": {
"instrument": self.config.id_instrument,
"night": night,
"wave": self.config.id_wave,
},
"freq_comb_master": {
"instrument": self.config.id_instrument,
"night": night,
"comb": self.config.id_comb,
},
"science": {
"instrument": self.config.id_instrument,
"night": night,
"target": target,
"spec": self.config.id_spec,
},
}
# Add channel filter if this instrument has separate files per channel
if channel is not None and self.config.kw_channel is not None:
id_channel = self.config.id_channel
channels = self.config.channels
channel_id = (
id_channel[channels.index(channel)] if channel in channels else channel
)
for key in expectations:
expectations[key]["channel"] = channel_id
return expectations
[docs]
def populate_filters(self, files):
"""Extract values from the fits headers and store them in self.filters
Parameters
----------
files : list(str)
list of fits files
Returns
-------
filters: list(Filter)
list of populated filters (identical to self.filters)
"""
# Empty filters
for _, fil in self.filters.items():
fil.clear()
for f in tqdm(files):
with fits.open(f) as hdu:
h = hdu[0].header
for _, fil in self.filters.items():
fil.collect(h)
return self.filters
[docs]
def apply_filters(self, files, expected, steps=None):
"""
Determine the relevant files for a given set of expected values.
Parameters
----------
files : list(files)
list if fits files
expected : dict
dictionary with expected header values for each reduction step
steps : list, optional
list of steps that will be run. If provided, warnings about
missing files are only shown for these steps.
Returns
-------
files: list((dict, dict))
list of files. The first element of each tuple is the used setting,
and the second are the files for each step.
"""
# Fill the filters with header information
self.populate_filters(files)
# Use the header information determined in populate filters
# to find potential science and calibration files in the list of files
# result = {step : [ {setting : value}, [files] ] }
result = {}
for step, values in expected.items():
result[step] = []
data = {}
for name, value in values.items():
# For 'night', don't filter during classification
# - get all nights so find_closest can work
if name == self.night:
value = None
if isinstance(value, list):
for v in value:
data[name] = self.filters[name].classify(v)
if len(data[name]) > 0:
break
else:
data[name] = self.filters[name].classify(value)
# Get all combinations of possible filter values
# e.g. if several nights are allowed
for thingy in product(*data.values()):
mask = np.copy(thingy[0][1])
for i in range(1, len(thingy)):
mask &= thingy[i][1]
if np.count_nonzero(mask) == 0:
continue
d = {k: v[0] for k, v in zip(values.keys(), thingy, strict=False)}
f = files[mask]
result[step].append((d, f))
# files = [{setting: value}, {step: files}]
files = []
settings = {}
for shared in self.shared:
# Check if user specified a value for this shared parameter
sample_expected = expected.get(self.science, {}).get(shared)
if sample_expected is not None and sample_expected != "":
# User specified a value - use only matching values from data
keys = self.filters[shared].classify(sample_expected)
keys = [k for k, _ in keys if k is not None]
else:
# No filter - use all unique values
keys = [k for k in set(self.filters[shared].data) if k is not None]
settings[shared] = keys
values = [settings[k] for k in self.shared]
for setting in product(*values):
setting = dict(zip(self.shared, setting, strict=False))
night = setting[self.night]
f = {}
# For each step look for files with matching settings
for step, step_data in result.items():
f[step] = []
for step_key, step_files in step_data:
match = [
setting[shared] == step_key[shared]
for shared in self.shared
if shared in step_key.keys()
]
if all(match):
f[step] = step_files
break
# If no matching files are found ...
if len(f[step]) == 0:
if step not in self.find_closest:
# Show a warning (only for requested steps)
if steps is None or step in steps:
logger.warning(
"Could not find any files for step '%s' with settings %s, sharing parameters %s",
step,
setting,
self.shared,
)
else:
# Or find the closest night instead
j = None
for i, (step_key, _) in enumerate(step_data):
match = [
setting[shared] == step_key[shared]
for shared in self.shared
if shared in step_key.keys() and shared != self.night
]
if all(match):
if j is None:
j = i
else:
diff_old = abs(step_data[j][0][self.night] - night)
diff_new = abs(step_data[i][0][self.night] - night)
if diff_new < diff_old:
j = i
if j is None:
# We still dont find any files (only warn for requested steps)
if steps is None or step in steps:
logger.warning(
"Could not find any files for step '%s' in any night with settings %s, sharing parameters %s",
step,
setting,
self.shared,
)
else:
# We found files in a close night
closest_key, closest_files = step_data[j]
logger.warning(
"Using '%s' files from night %s for observations of night %s",
step,
night,
closest_key["night"],
)
f[step] = closest_files
if any(len(a) > 0 for a in f.values()):
files.append((setting, f))
if len(files) == 0:
logger.warning(
"No %s files found matching the expected values %s",
self.science,
expected[self.science],
)
return files
[docs]
def sort_files(self, input_dir, target, night, *args, steps=None, **kwargs):
"""
Sort a set of fits files into different categories
types are: bias, flat, wavecal, orderdef, spec
Parameters
----------
input_dir : str
input directory containing the files to sort
target : str
name of the target as in the fits headers
night : str
observation night, possibly with wildcards
channel : str
instrument channel
steps : list, optional
list of steps that will be run. If provided, warnings about
missing files are only shown for these steps.
Returns
-------
files_per_night : list[dict{str:dict{str:list[str]}}]
a list of file sets, one entry per night, where each night consists of a dictionary with one entry per setting,
each fileset has five lists of filenames: "bias", "flat", "order", "wave", "spec", organised in another dict
nights_out : list[datetime]
a list of observation times, same order as files_per_night
"""
input_dir = input_dir.format(
**kwargs, target=target, night=night, instrument=self.name
)
files = self.find_files(input_dir)
ev = self.get_expected_values(target, night, *args, **kwargs)
files = self.apply_filters(files, ev, steps=steps)
return files
[docs]
def discover_channels(self, input_dir):
"""Discover available channels from files in the input directory.
Override in subclasses for instruments that require channel specification.
The base implementation returns [None] which means no channel filtering.
Parameters
----------
input_dir : str
Directory containing input files
Returns
-------
channels : list
List of channel identifiers found in the data
"""
return [None]
[docs]
def get_wavecal_filename(self, header, channel, **kwargs):
"""Get the filename of the pre-existing wavelength solution for the current setting
Parameters
----------
header : fits.header, dict
header of the wavelength calibration file
channel : str
instrument channel
Returns
-------
filename : str
name of the wavelength solution file
"""
specifier = header.get(self.config.wavecal_specifier or "", "")
cwd = os.path.dirname(__file__)
fname = f"wavecal_{channel}_{specifier}.npz"
fname = os.path.join(cwd, self.name.upper(), fname)
return fname
[docs]
def get_supported_channels(self):
return self.channels
[docs]
def get_settings_fallbacks(self, channel):
"""Return channel names to try when looking up settings files.
Searched in order: most specific first, least specific last.
The base settings.json is always the final fallback (handled by the caller).
Override in subclasses for instruments with composite channel names.
"""
return [channel] if channel else []
[docs]
def get_mask_filename(self, channel, **kwargs):
c = channel.lower() if channel else ""
fname = f"mask_{c}.fits.gz" if c else "mask.fits.gz"
cwd = os.path.dirname(__file__)
fname = os.path.join(cwd, self.name.upper(), fname)
return fname
[docs]
def get_wavelength_range(self, header, channel, **kwargs):
return self.get("wavelength_range", header, channel)
class COMMON(Instrument):
    """Instrument that uses the default configuration in defaults/config.yaml."""

    def load_info(self):
        """Load the default/common instrument config from defaults/config.yaml

        Returns
        -------
        config : InstrumentConfig
            validated configuration model
        info : dict
            the configuration as read from the YAML file
        """
        defaults_file = os.path.join(
            os.path.dirname(__file__), "defaults", "config.yaml"
        )
        with open(defaults_file) as stream:
            raw = yaml.safe_load(stream)
        return InstrumentConfig(**raw), raw
def create_custom_instrument(
    name, extension=0, info=None, mask_file=None, wavecal_file=None, **overrides
):
    """Create a custom instrument for reducing data from unsupported spectrographs.

    The configuration starts from the COMMON defaults and is then updated,
    in this order, with ``info``, ``extension`` and ``overrides``.

    Parameters
    ----------
    name : str
        Instrument name (used for output file naming)
    extension : int or str
        FITS extension to read data from (default: 0)
    info : dict, str or os.PathLike, optional
        Instrument properties as a dict, or path to a config file (JSON/YAML).
        Files ending in ".yaml" or ".yml" are read as YAML, all others as JSON.
        If None, uses sensible defaults. Dict values can be literal numbers
        or FITS header keyword strings (looked up at runtime).
    mask_file : str, optional
        Path to bad pixel mask FITS file
    wavecal_file : str, optional
        Path to wavelength calibration file
    **overrides
        Additional instrument properties, e.g. ``gain=1.1``, ``readnoise=5``,
        ``orientation=4``. These override values from the info dict.
        See ``instruments/defaults/config.yaml`` for available keys.

    Returns
    -------
    Instrument
        Configured instrument instance
    """

    class CUSTOM(Instrument):
        def __init__(self):
            super().__init__()
            # The base class derives the name from the class name ("custom")
            self.name = name

        def load_info(self):
            base_info = COMMON().info.copy()
            if isinstance(info, dict):
                base_info.update(info)
            elif isinstance(info, (str, os.PathLike)):
                # Accept pathlib.Path as well, it used to be silently ignored
                path = os.fspath(info)
                with open(path) as f:
                    if path.endswith((".yaml", ".yml")):
                        # An empty YAML document loads as None
                        base_info.update(yaml.safe_load(f) or {})
                    else:
                        base_info.update(json.load(f))
            base_info["extension"] = extension
            base_info.update(overrides)
            config = InstrumentConfig(**base_info)
            return config, base_info

        def get_extension(self, header, channel):
            return extension

        def get_mask_filename(self, channel, **kwargs):
            return mask_file

        def get_wavecal_filename(self, header, channel, **kwargs):
            return wavecal_file

    return CUSTOM()