Source code for pyreduce.instruments.common

"""
Abstract parent module for all other instruments.
Contains general functionality that instrument subclasses may override.
"""

import datetime
import glob
import json
import logging
import os.path
from itertools import product

import numpy as np
import yaml
from astropy.io import fits
from astropy.time import Time
from dateutil import parser
from tqdm import tqdm

from ..clipnflip import clipnflip
from .filters import ChannelFilter, Filter, InstrumentFilter, NightFilter, ObjectFilter
from .models import InstrumentConfig

logger = logging.getLogger(__name__)


def parse_iraf_section(section_str):
    """Parse IRAF-style section string into pixel coordinates.

    IRAF format: [x1:x2,y1:y2] where coordinates are 1-based inclusive.
    Whitespace around the brackets or around the individual numbers is
    tolerated.

    Parameters
    ----------
    section_str : str
        Section string like "[28:1179,1:4616]"

    Returns
    -------
    x1, x2, y1, y2 : int
        1-based inclusive pixel coordinates

    Raises
    ------
    ValueError
        If the string is not made of exactly two comma-separated
        ``a:b`` integer ranges.
    """
    # Strip whitespace before the brackets: str.strip("[]") stops at the first
    # character that is not a bracket, so a padded value such as
    # " [1:10,1:20] " would otherwise keep its brackets and fail int().
    s = section_str.strip().strip("[]")
    x_part, y_part = s.split(",")
    x1, x2 = map(int, x_part.split(":"))
    y1, y2 = map(int, y_part.split(":"))
    return x1, x2, y1, y2
def find_first_index(arr, value):
    """Return the position of the first element of ``arr`` equal to ``value``.

    Raises
    ------
    KeyError
        If no element of ``arr`` compares equal to ``value``.
    """
    matching_positions = (pos for pos, item in enumerate(arr) if item == value)
    try:
        first = next(matching_positions)
    except StopIteration as exhausted:
        raise KeyError(f"Value {value} not found") from exhausted
    return first
def observation_date_to_night(observation_date):
    """Convert an observation timestamp into the date of the observation night

    Nights start at noon (12:00) and end at noon the next day, i.e. an
    observation taken before 12:00 is assigned to the previous calendar date.

    Parameters
    ----------
    observation_date : str or datetime.datetime or None
        timestamp of the observation. Strings are parsed with
        ``dateutil.parser.parse``; datetime objects are used as they are.

    Returns
    -------
    night : datetime.date or None
        night of the observation, or None if the timestamp is missing
        (``None`` or the empty string)
    """
    # A missing header value may arrive either as "" or as None
    if observation_date is None or observation_date == "":
        return None

    # dateutil only accepts strings/streams, so skip parsing for datetimes
    if not isinstance(observation_date, datetime.datetime):
        observation_date = parser.parse(observation_date)

    # Morning hours still belong to the night that started the day before
    if observation_date.hour < 12:
        observation_date -= datetime.timedelta(days=1)
    return observation_date.date()
class getter:
    """Get data from a header/dict, based on the given channel, and applies replacements"""

    def __init__(self, header, info, channel):
        self.header = header
        self.info = info.copy()

        # Position of the requested channel within info["channels"];
        # falls back to the first entry when the lookup raises KeyError.
        try:
            self.index = find_first_index(info["channels"], channel.upper())
        except KeyError:
            logger.warning("No instrument channels found in instrument info")
            self.index = 0

        # Reduce every per-channel list to the entry of the selected channel
        for name, entry in list(self.info.items()):
            if isinstance(entry, list):
                self.info[name] = entry[self.index]

    def __call__(self, key, alt=None):
        return self.get(key, alt)

    def get(self, key, alt=None):
        """Get data

        Parameters
        ----------
        key : str
            key of the data in the header
        alt : obj, optional
            alternative value, if key does not exist (default: None)

        Returns
        -------
        value : obj
            value found in header (or alternatively alt)
        """
        resolved = self.info.get(key, key)
        if not isinstance(resolved, str):
            # Non-string entries are returned as they are, without a header lookup
            return resolved
        # String entries name a header keyword and may contain {placeholders}
        # that refer to other entries of the info dictionary
        keyword = resolved.format(**self.info)
        return self.header.get(keyword, alt)
[docs] class Instrument: """ Abstract parent class for all instruments Handles the instrument specific information """ def __init__(self): #:str: Name of the instrument (lowercase) self.name = self.__class__.__name__.lower() #:InstrumentConfig: Validated configuration model #:dict: Information about the instrument (for backward compatibility) self.config, self.info = self.load_info() self.filters = { "instrument": InstrumentFilter(self.config.instrument, regex=True), "night": NightFilter(self.config.date, timeformat=self.config.date_format), "target": ObjectFilter(self.config.target, regex=True), "bias": Filter(self.config.kw_bias), "flat": Filter(self.config.kw_flat), "trace": Filter(self.config.kw_orders), "curvature": Filter(self.config.kw_curvature), "scatter": Filter(self.config.kw_scatter), "wave": Filter(self.config.kw_wave), "comb": Filter(self.config.kw_comb), "spec": Filter(self.config.kw_spec), } self.night = "night" self.science = "science" self.shared = ["instrument", "night"] # Directory containing instrument config files self._inst_dir = os.path.join(os.path.dirname(__file__), self.name.upper()) # Add channel filter if kw_channel is defined (for instruments with separate files per channel) if self.config.kw_channel is not None: self.filters["channel"] = ChannelFilter(self.config.kw_channel) self.shared.append("channel") self.find_closest = [ "bias", "flat", "wavecal_master", "freq_comb_master", "trace", "scatter", "curvature", ] def __str__(self): return self.name @property def channels(self) -> list[str] | None: """Available instrument channels.""" return self.config.channels @property def extension(self) -> int | str | list: """FITS extension(s) to read.""" return self.config.extension @property def orientation(self) -> int | list[int]: """Detector orientation code(s).""" return self.config.orientation @property def id_instrument(self) -> str: """Instrument identifier for header matching.""" return self.config.id_instrument
[docs] def get(self, key, header, channel, alt=None): get = getter(header, self.info, channel) return get(key, alt=alt)
[docs] def get_extension(self, header, channel): channel = channel.upper() ext = self.extension # Use property if isinstance(ext, list): ichannel = find_first_index(self.channels, channel) ext = ext[ichannel] return ext
[docs] def load_info(self): """ Load static instrument information Either as fits header keywords or static values Returns ------ config : InstrumentConfig Validated Pydantic model info : dict(str:object) dictionary of REDUCE names for properties to Header keywords/static values """ # Tips & Tricks: # if several channels are supported, use a list for channels # if a value changes depending on the channel, use a list with the same order as "channels" # you can also use values from this dictionary as placeholders using {name}, just like str.format this = os.path.dirname(__file__) inst_dir = os.path.join(this, self.name.upper()) yaml_fname = os.path.join(inst_dir, "config.yaml") json_fname = os.path.join(inst_dir, "config.json") if os.path.exists(yaml_fname): with open(yaml_fname) as f: info = yaml.safe_load(f) elif os.path.exists(json_fname): with open(json_fname) as f: info = json.load(f) else: raise FileNotFoundError( f"No instrument config found for {self.name} " f"(tried {yaml_fname} and {json_fname})" ) # Validate with Pydantic (strict - invalid config is a bug) config = InstrumentConfig(**info) return config, info
[docs] def get_amplifier_extensions(self, header): """Get list of amplifier extensions if multi-amp mode is configured. Parameters ---------- header : fits.Header Primary FITS header Returns ------- list or None List of extension names/indices if multi-amp, None otherwise """ if self.config.amplifiers is None: return None amp_config = self.config.amplifiers # Get number of amplifiers if isinstance(amp_config.count, int): n_amps = amp_config.count else: n_amps = header.get(amp_config.count) if n_amps is None: logger.warning( "Amplifier count keyword '%s' not found in header", amp_config.count, ) return None # Generate extension names from template extensions = [] for n in range(1, n_amps + 1): ext_name = amp_config.extension_template.format(n=n) extensions.append(ext_name) return extensions
[docs] def assemble_amplifiers(self, hdu, amp_extensions, channel): """Assemble multi-amplifier readout into single frame. Reads data from multiple FITS extensions, each representing one amplifier's readout region, and assembles them into a single frame using DATASEC/DETSEC mappings. Parameters ---------- hdu : HDUList Open FITS file amp_extensions : list List of extension names to read channel : str Instrument channel Returns ------- data : ndarray Assembled frame (float32) header : fits.Header Combined header with per-amplifier e_gain{n}, e_readn{n} """ amp_config = self.config.amplifiers h_prime = hdu[0].header # First pass: determine output size and collect amp info xmax, ymax = 0, 0 amp_info = [] for ext in amp_extensions: h = hdu[ext].header datasec = parse_iraf_section(h[amp_config.datasec]) detsec = parse_iraf_section(h[amp_config.detsec]) # DETSEC gives the destination in the assembled image xmax = max(xmax, detsec[1]) ymax = max(ymax, detsec[3]) amp_info.append( { "ext": ext, "datasec": datasec, "detsec": detsec, "gain": h.get(amp_config.gain, 1.0), "readnoise": h.get(amp_config.readnoise, 0.0), } ) # Allocate output array assembled = np.zeros((ymax, xmax), dtype=np.float32) # Second pass: place each amplifier's data for info in amp_info: # Extract valid data region (DATASEC is 1-based inclusive) dx1, dx2, dy1, dy2 = info["datasec"] raw = hdu[info["ext"]].data[dy1 - 1 : dy2, dx1 - 1 : dx2] # Place into output at DETSEC location (1-based inclusive) ox1, ox2, oy1, oy2 = info["detsec"] assembled[oy1 - 1 : oy2, ox1 - 1 : ox2] = raw # Build combined header header = hdu[amp_extensions[0]].header.copy() header.extend(h_prime, strip=False) # Add standard header info first (sets e_orient, e_transpose, etc.) 
header = self.add_header_info(header, channel) header["e_input"] = ( os.path.basename(hdu.filename()), "Original input filename", ) # Store per-amplifier calibration values # Use e_namps (not e_ampl) to avoid triggering clipnflip's multi-amp path # since we've already assembled the frame header["e_namps"] = (len(amp_info), "Number of amplifiers in raw data") for i, info in enumerate(amp_info, 1): header[f"HIERARCH e_gain{i}"] = ( info["gain"], f"Gain for amplifier {i}", ) header[f"HIERARCH e_readn{i}"] = ( info["readnoise"], f"Readnoise for amplifier {i}", ) # Override e_gain/e_readn with median across amplifiers gains = [info["gain"] for info in amp_info] readnoises = [info["readnoise"] for info in amp_info] header["e_gain"] = (float(np.median(gains)), "Median gain across amplifiers") header["e_readn"] = ( float(np.median(readnoises)), "Median readnoise across amplifiers", ) # Override bounds for assembled frame (full frame, no prescan/overscan) header["e_xlo"] = 0 header["e_xhi"] = xmax header["e_ylo"] = 0 header["e_yhi"] = ymax return assembled, header
[docs] def load_fits( self, fname, channel, extension=None, mask=None, header_only=False, dtype=None ): """ load fits file, REDUCE style primary and extension header are combined channel-specific info is applied to header data is clipnflipped mask is applied For multi-amplifier instruments, data from multiple extensions is assembled into a single frame before clipnflip. Parameters ---------- fname : str filename instrument : str name of the instrument channel : str instrument channel extension : int data extension of the FITS file to load mask : array, optional mask to add to the data header_only : bool, optional only load the header, not the data dtype : str, optional numpy datatype to convert the read data to Returns -------- data : masked_array FITS data, clipped and flipped, and with mask header : fits.header FITS header (Primary and Extension + channel info) ONLY the header is returned if header_only is True """ channel = channel.upper() hdu = fits.open(fname) h_prime = hdu[0].header # Check for multi-amplifier mode amp_extensions = self.get_amplifier_extensions(h_prime) if amp_extensions is not None: # Multi-amplifier path: assemble from multiple extensions if header_only: # For header_only, just return first extension header + primary header = hdu[amp_extensions[0]].header.copy() header.extend(h_prime, strip=False) header = self.add_header_info(header, channel) header["e_input"] = (os.path.basename(fname), "Original input filename") hdu.close() return header data, header = self.assemble_amplifiers(hdu, amp_extensions, channel) try: data = clipnflip(data, header) except IndexError as e: hdu.close() raise ValueError( f"Failed to load {fname} for channel '{channel}': {e}\n" f"This usually means the file does not contain data for this channel." 
) from None else: # Single extension path (original behavior) if extension is None: extension = self.get_extension(h_prime, channel) header = hdu[extension].header if extension != 0: header.extend(h_prime, strip=False) header = self.add_header_info(header, channel) header["e_input"] = (os.path.basename(fname), "Original input filename") if header_only: hdu.close() return header try: data = clipnflip(hdu[extension].data, header) except IndexError as e: hdu.close() raise ValueError( f"Failed to load {fname} for channel '{channel}' (extension {extension}): {e}\n" f"This usually means the file does not contain data for this channel." ) from None if dtype is not None: data = data.astype(dtype) data = np.ma.masked_array(data, mask=mask) hdu.close() return data, header
[docs] def add_header_info(self, header, channel, **kwargs): """read data from header and add it as REDUCE keyword back to the header Parameters ---------- header : fits.header, dict header to read/write info from/to channel : str instrument channel Returns ------- header : fits.header, dict header with added information """ info = self.info get = getter(header, info, channel) # Use HIERARCH prefix only for FITS Header objects to avoid warnings # For dict objects, HIERARCH is not needed and would break key access from astropy.io.fits import Header as FitsHeader hierarch = "HIERARCH " if isinstance(header, FitsHeader) else "" header[f"{hierarch}e_instrument"] = get("instrument", self.__class__.__name__) header[f"{hierarch}e_telescope"] = get("telescope", "") header[f"{hierarch}e_exptime"] = get("exposure_time", 0) jd = get("date") if jd is not None: jd = Time(jd, format=self.info.get("date_format", "fits")) jd = jd.to_value("mjd") header["e_orient"] = get("orientation", 0) # As per IDL rotate if orient is 4 or larger and transpose is undefined # the image is transposed header[f"{hierarch}e_transpose"] = get( "transpose", (header["e_orient"] % 8 >= 4) ) naxis_x = get("naxis_x", 0) naxis_y = get("naxis_y", 0) prescan_x = get("prescan_x", 0) overscan_x = get("overscan_x", 0) prescan_y = get("prescan_y", 0) overscan_y = get("overscan_y", 0) header["e_xlo"] = prescan_x header["e_xhi"] = naxis_x - overscan_x header["e_ylo"] = prescan_y header["e_yhi"] = naxis_y - overscan_y header["e_gain"] = get("gain", 1) header["e_readn"] = get("readnoise", 0) header["e_sky"] = get("sky", 0) header["e_drk"] = get("dark", 0) header["e_backg"] = header["e_gain"] * (header["e_drk"] + header["e_sky"]) header["e_imtype"] = get("image_type") header["e_ctg"] = get("category") header["e_ra"] = get("ra", 0) header["e_dec"] = get("dec", 0) header["e_jd"] = jd header["e_obslon"] = get("longitude") header["e_obslat"] = get("latitude") header["e_obsalt"] = get("altitude") if 
info.get("wavecal_element", None) is not None: header["HIERARCH e_wavecal_element"] = get( "wavecal_element", info.get("wavecal_element", None) ) return header
[docs] def find_files(self, input_dir): """Find fits files in the given folder Parameters ---------- input_dir : string directory to look for fits and fits.gz files in, may include bash style wildcards Returns ------- files: array(string) absolute path filenames """ files = glob.glob(input_dir + "/*.fits") files += glob.glob(input_dir + "/*.fits.gz") files = np.array(files) return files
[docs] def get_expected_values(self, target, night, channel=None, **kwargs): expectations = { "bias": { "instrument": self.config.id_instrument, "night": night, "bias": self.config.id_bias, }, "flat": { "instrument": self.config.id_instrument, "night": night, "flat": self.config.id_flat, }, "trace": { "instrument": self.config.id_instrument, "night": night, "trace": self.config.id_orders, }, "scatter": { "instrument": self.config.id_instrument, "night": night, "scatter": self.config.id_scatter, }, "curvature": { "instrument": self.config.id_instrument, "night": night, "curvature": self.config.id_curvature, }, "wavecal_master": { "instrument": self.config.id_instrument, "night": night, "wave": self.config.id_wave, }, "freq_comb_master": { "instrument": self.config.id_instrument, "night": night, "comb": self.config.id_comb, }, "science": { "instrument": self.config.id_instrument, "night": night, "target": target, "spec": self.config.id_spec, }, } # Add channel filter if this instrument has separate files per channel if channel is not None and self.config.kw_channel is not None: id_channel = self.config.id_channel channels = self.config.channels channel_id = ( id_channel[channels.index(channel)] if channel in channels else channel ) for key in expectations: expectations[key]["channel"] = channel_id return expectations
[docs] def populate_filters(self, files): """Extract values from the fits headers and store them in self.filters Parameters ---------- files : list(str) list of fits files Returns ------- filters: list(Filter) list of populated filters (identical to self.filters) """ # Empty filters for _, fil in self.filters.items(): fil.clear() for f in tqdm(files): with fits.open(f) as hdu: h = hdu[0].header for _, fil in self.filters.items(): fil.collect(h) return self.filters
[docs] def apply_filters(self, files, expected, steps=None): """ Determine the relevant files for a given set of expected values. Parameters ---------- files : list(files) list if fits files expected : dict dictionary with expected header values for each reduction step steps : list, optional list of steps that will be run. If provided, warnings about missing files are only shown for these steps. Returns ------- files: list((dict, dict)) list of files. The first element of each tuple is the used setting, and the second are the files for each step. """ # Fill the filters with header information self.populate_filters(files) # Use the header information determined in populate filters # to find potential science and calibration files in the list of files # result = {step : [ {setting : value}, [files] ] } result = {} for step, values in expected.items(): result[step] = [] data = {} for name, value in values.items(): # For 'night', don't filter during classification # - get all nights so find_closest can work if name == self.night: value = None if isinstance(value, list): for v in value: data[name] = self.filters[name].classify(v) if len(data[name]) > 0: break else: data[name] = self.filters[name].classify(value) # Get all combinations of possible filter values # e.g. 
if several nights are allowed for thingy in product(*data.values()): mask = np.copy(thingy[0][1]) for i in range(1, len(thingy)): mask &= thingy[i][1] if np.count_nonzero(mask) == 0: continue d = {k: v[0] for k, v in zip(values.keys(), thingy, strict=False)} f = files[mask] result[step].append((d, f)) # files = [{setting: value}, {step: files}] files = [] settings = {} for shared in self.shared: # Check if user specified a value for this shared parameter sample_expected = expected.get(self.science, {}).get(shared) if sample_expected is not None and sample_expected != "": # User specified a value - use only matching values from data keys = self.filters[shared].classify(sample_expected) keys = [k for k, _ in keys if k is not None] else: # No filter - use all unique values keys = [k for k in set(self.filters[shared].data) if k is not None] settings[shared] = keys values = [settings[k] for k in self.shared] for setting in product(*values): setting = dict(zip(self.shared, setting, strict=False)) night = setting[self.night] f = {} # For each step look for files with matching settings for step, step_data in result.items(): f[step] = [] for step_key, step_files in step_data: match = [ setting[shared] == step_key[shared] for shared in self.shared if shared in step_key.keys() ] if all(match): f[step] = step_files break # If no matching files are found ... 
if len(f[step]) == 0: if step not in self.find_closest: # Show a warning (only for requested steps) if steps is None or step in steps: logger.warning( "Could not find any files for step '%s' with settings %s, sharing parameters %s", step, setting, self.shared, ) else: # Or find the closest night instead j = None for i, (step_key, _) in enumerate(step_data): match = [ setting[shared] == step_key[shared] for shared in self.shared if shared in step_key.keys() and shared != self.night ] if all(match): if j is None: j = i else: diff_old = abs(step_data[j][0][self.night] - night) diff_new = abs(step_data[i][0][self.night] - night) if diff_new < diff_old: j = i if j is None: # We still dont find any files (only warn for requested steps) if steps is None or step in steps: logger.warning( "Could not find any files for step '%s' in any night with settings %s, sharing parameters %s", step, setting, self.shared, ) else: # We found files in a close night closest_key, closest_files = step_data[j] logger.warning( "Using '%s' files from night %s for observations of night %s", step, night, closest_key["night"], ) f[step] = closest_files if any(len(a) > 0 for a in f.values()): files.append((setting, f)) if len(files) == 0: logger.warning( "No %s files found matching the expected values %s", self.science, expected[self.science], ) return files
[docs] def sort_files(self, input_dir, target, night, *args, steps=None, **kwargs): """ Sort a set of fits files into different categories types are: bias, flat, wavecal, orderdef, spec Parameters ---------- input_dir : str input directory containing the files to sort target : str name of the target as in the fits headers night : str observation night, possibly with wildcards channel : str instrument channel steps : list, optional list of steps that will be run. If provided, warnings about missing files are only shown for these steps. Returns ------- files_per_night : list[dict{str:dict{str:list[str]}}] a list of file sets, one entry per night, where each night consists of a dictionary with one entry per setting, each fileset has five lists of filenames: "bias", "flat", "order", "wave", "spec", organised in another dict nights_out : list[datetime] a list of observation times, same order as files_per_night """ input_dir = input_dir.format( **kwargs, target=target, night=night, instrument=self.name ) files = self.find_files(input_dir) ev = self.get_expected_values(target, night, *args, **kwargs) files = self.apply_filters(files, ev, steps=steps) return files
[docs] def discover_channels(self, input_dir): """Discover available channels from files in the input directory. Override in subclasses for instruments that require channel specification. The base implementation returns [None] which means no channel filtering. Parameters ---------- input_dir : str Directory containing input files Returns ------- channels : list List of channel identifiers found in the data """ return [None]
[docs] def get_wavecal_filename(self, header, channel, **kwargs): """Get the filename of the pre-existing wavelength solution for the current setting Parameters ---------- header : fits.header, dict header of the wavelength calibration file channel : str instrument channel Returns ------- filename : str name of the wavelength solution file """ specifier = header.get(self.config.wavecal_specifier or "", "") cwd = os.path.dirname(__file__) fname = f"wavecal_{channel}_{specifier}.npz" fname = os.path.join(cwd, self.name.upper(), fname) return fname
[docs] def get_supported_channels(self): # Single-channel instruments with no `channels:` in config.yaml # iterate once with channel=None. return self.channels or [None]
[docs] def get_settings_fallbacks(self, channel): """Return channel names to try when looking up settings files. Searched in order: most specific first, least specific last. The base settings.json is always the final fallback (handled by the caller). Override in subclasses for instruments with composite channel names. """ return [channel] if channel else []
[docs] def get_mask_filename(self, channel, **kwargs): c = channel.lower() if channel else "" fname = f"mask_{c}.fits.gz" if c else "mask.fits.gz" cwd = os.path.dirname(__file__) fname = os.path.join(cwd, self.name.upper(), fname) return fname
[docs] def get_wavelength_range(self, header, channel, **kwargs): return self.get("wavelength_range", header, channel)
class COMMON(Instrument):
    """Instrument that reads its configuration from defaults/config.yaml."""

    def load_info(self):
        """Load the default/common instrument config from defaults/config.yaml

        Returns
        -------
        config : InstrumentConfig
            model built from the loaded values
        info : dict
            the values as loaded from the YAML file
        """
        defaults_path = os.path.join(
            os.path.dirname(__file__), "defaults", "config.yaml"
        )
        with open(defaults_path) as stream:
            raw_info = yaml.safe_load(stream)
        return InstrumentConfig(**raw_info), raw_info
def create_custom_instrument(
    name, extension=0, info=None, mask_file=None, wavecal_file=None, **overrides
):
    """Create a custom instrument for reducing data from unsupported spectrographs.

    Parameters
    ----------
    name : str
        Instrument name (used for output file naming)
    extension : int or str
        FITS extension to read data from (default: 0)
    info : dict or str or os.PathLike, optional
        Instrument properties as a dict, or path to a config file (JSON/YAML).
        Files ending in ``.yaml``/``.yml`` (case-insensitive) are read as YAML,
        everything else as JSON. If None, uses sensible defaults.
        Dict values can be literal numbers or FITS header keyword strings
        (looked up at runtime).
    mask_file : str, optional
        Path to bad pixel mask FITS file
    wavecal_file : str, optional
        Path to wavelength calibration file
    **overrides
        Additional instrument properties, e.g. ``gain=1.1``, ``readnoise=5``,
        ``orientation=4``. These override values from the info dict.
        See ``instruments/defaults/config.yaml`` for available keys.

    Returns
    -------
    Instrument
        Configured instrument instance

    Raises
    ------
    TypeError
        If ``info`` is neither None, a dict, nor a path
    """

    class CUSTOM(Instrument):
        def __init__(self):
            super().__init__()
            # Instrument.__init__ derives the name from the class name;
            # replace it with the user supplied one afterwards
            self.name = name

        def load_info(self):
            # Start from the common defaults, then layer the user values on top
            base_info = COMMON().info.copy()
            if info is not None:
                if isinstance(info, dict):
                    base_info.update(info)
                elif isinstance(info, (str, os.PathLike)):
                    path = os.fspath(info)
                    with open(path) as f:
                        if path.lower().endswith((".yaml", ".yml")):
                            base_info.update(yaml.safe_load(f))
                        else:
                            base_info.update(json.load(f))
                else:
                    # Previously such values were silently ignored, which hid
                    # misconfigured instruments
                    raise TypeError(
                        "info must be a dict or a path to a JSON/YAML file, "
                        f"got {type(info).__name__}"
                    )
            # Explicit arguments take precedence over the info contents
            base_info["extension"] = extension
            base_info.update(overrides)
            config = InstrumentConfig(**base_info)
            return config, base_info

        def get_extension(self, header, channel):
            return extension

        def get_mask_filename(self, channel, **kwargs):
            return mask_file

        def get_wavecal_filename(self, header, channel, **kwargs):
            return wavecal_file

    return CUSTOM()