Source code for pyreduce.continuum_normalization

# -*- coding: utf-8 -*-
"""
Find the continuum level

Currently only splices orders together
First guess of the continuum is provided by the flat field
"""
import logging
from itertools import chain

import matplotlib.pyplot as plt
import numpy as np

from . import util

logger = logging.getLogger(__name__)

# np.seterr("raise")

def splice_orders(spec, wave, cont, sigm, scaling=True, plot=False, plot_title=None):
    """
    Splice orders together so that they form a continuous spectrum.

    This is achieved by linearly combining the overlapping regions of
    neighbouring orders, weighted by their relative errors.

    Parameters
    ----------
    spec : array[nord, ncol]
        Spectrum to splice, with separate orders
    wave : array[nord, ncol]
        Wavelength solution for each point
    cont : array[nord, ncol] or None
        Continuum, blaze function will do fine as well.
        If None, a continuum of ones is used.
    sigm : array[nord, ncol]
        Errors on the spectrum
    scaling : bool, optional
        If true, the continuum of each order is rescaled so that the median
        of spectrum/continuum is 1 (default: True)
    plot : bool, optional
        If true, will plot the spliced spectrum (default: False)
    plot_title : str, optional
        Figure title used when plotting (default: None)

    Returns
    -------
    spec, wave, cont, sigm : masked array[nord, ncol]
        Spliced spectrum. Points that are masked in spec, or where spec or
        cont are exactly 0, are masked in all four outputs.

    Notes
    -----
    Neighbouring orders that do not overlap in wavelength are skipped and
    left unchanged; no exception is raised for them.
    """
    nord, _ = spec.shape  # Number of sp. orders, Order length in pixels

    if cont is None:
        cont = np.ones_like(spec)

    # Just to be extra safe that they are all the same
    # NOTE(review): the np.ma.* calls in this function were restored from a
    # damaged listing in which they had been stripped - confirm against upstream.
    mask = (
        np.ma.getmaskarray(spec)
        | (np.ma.getdata(spec) == 0)
        | (np.ma.getdata(cont) == 0)
    )
    spec = np.ma.masked_array(spec, mask=mask)
    wave = np.ma.masked_array(np.ma.getdata(wave), mask=mask)
    cont = np.ma.masked_array(np.ma.getdata(cont), mask=mask)
    sigm = np.ma.masked_array(np.ma.getdata(sigm), mask=mask)

    if scaling:
        # Scale everything to roughly the same size, around spec/blaze = 1
        scale = np.ma.median(spec / cont, axis=1)
        cont *= scale[:, None]

    if plot:  # pragma: no cover
        plt.subplot(411)
        if plot_title is not None:
            plt.suptitle(plot_title)
        plt.title("Before")
        for i in range(nord):
            plt.plot(wave[i], spec[i] / cont[i])
        plt.ylim([0, 2])

        plt.subplot(412)
        plt.title("Before Error")
        for i in range(nord):
            plt.plot(wave[i], sigm[i] / cont[i])
        # The y-range is taken from the last order that was plotted
        plt.ylim((0, np.ma.median(sigm[-1] / cont[-1]) * 2))

    # Order with largest signal, everything is scaled relative to this order
    iref = np.argmax(np.ma.median(spec / cont, axis=1))

    # Walk from the reference order outwards, first towards order 0,
    # then towards the last order; each step pairs an already handled
    # order (first) with its next neighbour (second)
    first = chain(range(iref, 0, -1), range(iref, nord - 1))
    second = chain(range(iref - 1, -1, -1), range(iref + 1, nord))

    # Looping over order pairs
    for iord0, iord1 in zip(first, second):
        # Get data for current order
        # Note that those are just references to parts of the original data
        # any changes will also affect spec, wave, cont, and sigm
        s0, s1 = spec[iord0], spec[iord1]
        w0, w1 = wave[iord0], wave[iord1]
        c0, c1 = cont[iord0], cont[iord1]
        u0, u1 = sigm[iord0], sigm[iord1]

        # Calculate Overlap
        i0 = np.ma.where((w0 >= np.ma.min(w1)) & (w0 <= np.ma.max(w1)))
        i1 = np.ma.where((w1 >= np.ma.min(w0)) & (w1 <= np.ma.max(w0)))

        if i0[0].size == 0 or i1[0].size == 0:  # pragma: no cover
            # TODO: Orders dont overlap
            continue

        # Interpolate the overlapping region onto the wavelength grid of the
        # other order. All six interpolations are done before anything is
        # written back, so they see the unmodified orders.
        tmpS0 = util.bezier_interp(w1, s1, w0[i0])
        tmpB0 = util.bezier_interp(w1, c1, w0[i0])
        tmpU0 = util.bezier_interp(w1, u1, w0[i0])

        tmpS1 = util.bezier_interp(w0, s0, w1[i1])
        tmpB1 = util.bezier_interp(w0, c0, w1[i1])
        tmpU1 = util.bezier_interp(w0, u0, w1[i1])

        # Combine the two orders weighted by the relative error
        wgt0 = np.ma.vstack([c0[i0].data / u0[i0].data, tmpB0 / tmpU0]) ** 2
        wgt1 = np.ma.vstack([c1[i1].data / u1[i1].data, tmpB1 / tmpU1]) ** 2

        # utmp is the sum of weights returned by the weighted average
        s0[i0], utmp = np.ma.average(
            np.ma.vstack([s0[i0], tmpS0]), axis=0, weights=wgt0, returned=True
        )
        c0[i0] = np.ma.average([c0[i0], tmpB0], axis=0, weights=wgt0)
        u0[i0] = c0[i0] * utmp ** -0.5

        s1[i1], utmp = np.ma.average(
            np.ma.vstack([s1[i1], tmpS1]), axis=0, weights=wgt1, returned=True
        )
        c1[i1] = np.ma.average([c1[i1], tmpB1], axis=0, weights=wgt1)
        u1[i1] = c1[i1] * utmp ** -0.5

    if plot:  # pragma: no cover
        plt.subplot(413)
        plt.title("After")
        for i in range(nord):
            plt.plot(wave[i], spec[i] / cont[i], label="order=%i" % i)
        plt.ylim((0, 2))

        plt.subplot(414)
        plt.title("Error")
        for i in range(nord):
            plt.plot(wave[i], sigm[i] / cont[i], label="order=%i" % i)
        # The y-range is taken from the last order that was plotted
        plt.ylim((0, np.ma.median(sigm[-1] / cont[-1]) * 2))

    return spec, wave, cont, sigm
class Plot_Normalization:  # pragma: no cover
    """
    Interactive figure showing a spectrum and the current continuum fit.

    The figure holds two lines, labelled "Spectrum" and "Continuum Fit",
    whose data is replaced on every call to `plot`.

    Parameters
    ----------
    wsort, sB : array-like
        x and y values of the "Spectrum" line
    new_wave, contB : array-like
        x and y values of the "Continuum Fit" line
    iteration : int, optional
        Iteration number shown in the figure title (default: 0)
    title : str, optional
        Extra title, shown on its own line above the iteration number
        (default: None)
    """

    def __init__(self, wsort, sB, new_wave, contB, iteration=0, title=None):
        # Interactive mode, so that the figure can be redrawn while the
        # calling code keeps running
        plt.ion()
        self.fig = plt.figure()
        self.title = title
        self.fig.suptitle(self._suptitle(iteration))

        self.ax = self.fig.add_subplot(111)
        self.line1 = self.ax.plot(wsort, sB, label="Spectrum")[0]
        self.line2 = self.ax.plot(new_wave, contB, label="Continuum Fit")[0]
        self.ax.legend()

    def _suptitle(self, iteration):
        """Build the figure title for the given iteration number."""
        suptitle = f"Iteration: {iteration}"
        if self.title is not None:
            suptitle = f"{self.title}\n{suptitle}"
        return suptitle

    def plot(self, wsort, sB, new_wave, contB, iteration):
        """Replace the data of both lines, update the title and redraw."""
        self.fig.suptitle(self._suptitle(iteration))

        self.line1.set_xdata(wsort)
        self.line1.set_ydata(sB)
        self.line2.set_xdata(new_wave)
        self.line2.set_ydata(contB)

        self.fig.canvas.draw()
        self.fig.canvas.flush_events()

    def close(self):
        """Turn interactive mode off again and close this figure."""
        plt.ioff()
        # Close our own figure, not whichever figure happens to be current
        plt.close(self.fig)
def continuum_normalize(
    spec,
    wave,
    cont,
    sigm,
    iterations=10,
    smooth_initial=1e5,
    smooth_final=5e6,
    scale_vert=1,
    plot=True,
    plot_title=None,
):
    """Fit a continuum to a spectrum by slowly approaching it from the top.

    We exploit here that the continuum varies only on large wavelength scales,
    while individual lines act on much smaller scales

    TODO automatically find good parameters for smooth_initial and smooth_final
    TODO give variables better names

    Parameters
    ----------
    spec : masked array of shape (nord, ncol)
        Observed input spectrum, masked values describe column ranges
    wave : masked array of shape (nord, ncol)
        Wavelength solution of the spectrum
    cont : masked array of shape (nord, ncol)
        Initial continuum guess, for example based on the blaze
    sigm : masked array of shape (nord, ncol)
        Uncertainties of the spectrum (not used by the current implementation)
    iterations : int, optional
        Number of iterations of the algorithm,
        note that runtime roughly scales with the number of iterations squared
        (default: 10)
    smooth_initial : float, optional
        Smoothing parameter in the initial runs, usually smaller than
        smooth_final (default: 1e5)
    smooth_final : float, optional
        Smoothing parameter of the final run (default: 5e6)
    scale_vert : float, optional
        Vertical scale of the spectrum. Usually 1 if a previous normalization
        exists (default: 1)
    plot : bool, optional
        Whether to plot the current status and results or not (default: True)
    plot_title : str, optional
        Title used for the plots (default: None)

    Returns
    -------
    cont : masked array of shape (nord, ncol)
        New continuum
    """
    nord, _ = spec.shape

    # Values passed as `eps` to the smoothing filter: par2 in the inner
    # refinement loop, par4 in the final pass of each outer iteration
    par2 = 1e-4
    par4 = 0.01 * (1 - np.clip(2, None, 1 / np.sqrt(np.ma.median(spec))))

    # Work on a clipped copy of the initial continuum (values below 1 become 1)
    b = np.clip(cont, 1, None)
    mask = ~np.ma.getmaskarray(b)
    for i in range(nord):
        b[i, mask[i]] = util.middle(b[i, mask[i]], 1)
    cont = b

    # Create new equispaced wavelength grid
    # The step is half the spacing of the two central wavelength points
    tmp = wave.compressed()
    wmin = np.min(tmp)
    wmax = np.max(tmp)
    dwave = np.abs(tmp[tmp.size // 2] - tmp[tmp.size // 2 - 1]) * 0.5
    nwave = np.ceil((wmax - wmin) / dwave) + 1
    new_wave = np.linspace(wmin, wmax, int(nwave), endpoint=True)

    # Combine all orders into one big spectrum, sorted by wavelength
    # j selects the unique points, index maps them back onto all points
    wsort, j, index = np.unique(tmp, return_index=True, return_inverse=True)
    sB = (spec / cont).compressed()[j]

    # Get initial weights for each point
    weight = util.middle(sB, 0.5, x=wsort - wmin)
    weight = weight / util.middle(weight, 3 * smooth_initial) + np.concatenate(
        ([0], 2 * weight[1:-1] - weight[0:-2] - weight[2:], [0])
    )
    weight = np.clip(weight, 0, None)
    # TODO for some reason the interpolation messes up, use linear instead for now
    # weight = util.safe_interpolation(wsort, weight, new_wave)
    weight = np.interp(new_wave, wsort, weight)
    weight /= np.max(weight)

    # Interpolate Spectrum onto the new grid
    # ssB = util.safe_interpolation(wsort, sB, new_wave)
    ssB = np.interp(new_wave, wsort, sB)
    # Keep the scale of the continuum
    bbb = util.middle(cont.compressed()[j], 1)

    contB = np.ones_like(ssB)
    if plot:  # pragma: no cover
        p = Plot_Normalization(wsort, sB, new_wave, contB, 0, title=plot_title)

    try:
        for i in range(iterations):
            # Find new approximation of the top, smoothed by some parameter
            c = ssB / contB
            for _ in range(iterations):
                # NOTE(review): the call target was restored as util.top from a
                # damaged listing in which it had been stripped - confirm.
                _c = util.top(
                    c, smooth_initial, eps=par2, weight=weight, lambda2=smooth_final
                )
                # Never let the estimate drop below its previous value
                c = np.clip(_c, c, None)
            c = (
                util.top(
                    c, smooth_initial, eps=par4, weight=weight, lambda2=smooth_final
                )
                * contB
            )

            # Scale it and update the weights of each point
            contB = c * scale_vert
            contB = util.middle(contB, 1)
            weight = np.clip(ssB / contB, None, contB / np.clip(ssB, 1, None))

            # Plot the intermediate results
            if plot:  # pragma: no cover
                p.plot(wsort, sB, new_wave, contB, i)
    except ValueError:
        # Best effort: keep the last continuum estimate, but record the
        # traceback so that the reason for the abort is not lost
        logger.exception("Continuum fitting aborted")
    finally:
        if plot:  # pragma: no cover
            p.close()

    # Calculate the new continuum from intermediate values
    # new_cont = util.safe_interpolation(new_wave, contB, wsort)
    new_cont = np.interp(wsort, new_wave, contB)
    mask = np.ma.getmaskarray(cont)
    cont[~mask] = (new_cont * bbb)[index]

    # Final output plot
    if plot:  # pragma: no cover
        plt.plot(wave.ravel(), spec.ravel(), label="spec")
        plt.plot(wave.ravel(), cont.ravel(), label="cont")
        plt.legend(loc="best")
        if plot_title is not None:
            plt.title(plot_title)
        plt.xlabel("Wavelength [A]")
        plt.ylabel("Flux")

    return cont