"""Source code for pyreduce.rectify."""

import numpy as np
from scipy.interpolate import interp1d
from tqdm import tqdm

from . import util
from .extract import correct_for_curvature


def rectify_image(img, traces, extraction_height, trace_range):
    """Rectify image by extracting and straightening each trace.

    For every selected trace a band of rows centred on the trace is cut out
    of ``img`` column by column (via ``util.make_index``), so the trace runs
    horizontally through the cut-out. If the trace has a slit-curvature
    array with at least three rows, the cut-out is additionally passed
    through ``correct_for_curvature``.

    Parameters
    ----------
    img : array
        Input 2-D image, indexed as ``img[row, column]``.
    traces : list[Trace]
        Trace objects with ``pos`` (polynomial coefficients for
        ``np.polyval`` giving the trace centre row as a function of column),
        ``column_range`` (start, end column limits) and an optional ``slit``
        curvature array.
    extraction_height : float or array
        Extraction height. A scalar below 3 is a fraction of the median
        spacing between the selected trace centres, measured at the middle
        column (10 pixels is used when fewer than two traces are selected);
        any other scalar is in pixels. An array gives per-trace heights in
        pixels and may cover either the selected traces or all ``traces``
        (a full-length array is sliced with ``trace_range``).
    trace_range : tuple
        ``(start, end)`` slice bounds of the traces to process.

    Returns
    -------
    images : dict
        Rectified images keyed by position within the selected traces
        (0 for ``traces[start]``, 1 for the next one, ...).
    column_range : array
        Column ranges of the selected traces.
    extraction_height : array
        Integer extraction heights in pixels for the selected traces.
    """
    _, ncol = img.shape
    x = np.arange(ncol)

    ntrace_total = len(traces)
    selection = slice(trace_range[0], trace_range[1])
    traces = traces[selection]

    column_range = np.array([t.column_range for t in traces])
    xwd_arr = _rectify_extraction_heights(
        extraction_height, traces, ncol, ntrace_total, selection
    )

    images = {}
    for i, trace in enumerate(tqdm(traces, desc="Trace")):
        x_left_lim, x_right_lim = trace.column_range
        height = xwd_arr[i]

        # Remove the shape of the trace: cut a band of ``height`` rows whose
        # centre follows the (integer-truncated) trace position per column.
        ycen = np.polyval(trace.pos, x).astype(int)
        yb = ycen - height // 2
        yt = yb + height - 1
        index = util.make_index(yb, yt, x_left_lim, x_right_lim)
        img_order = img[index]

        # Correct for curvature using trace.slit if available
        # slit[1, :] = linear term coeffs, slit[2, :] = quadratic term coeffs
        if trace.slit is not None and trace.slit.shape[0] > 2:
            x_range = np.arange(x_left_lim, x_right_lim)
            p1 = np.polyval(trace.slit[1, :], x_range)
            p2 = np.polyval(trace.slit[2, :], x_range)
            img_order = correct_for_curvature(img_order, p1, p2, height)

        images[i] = img_order

    return images, column_range, xwd_arr


def _rectify_extraction_heights(extraction_height, traces, ncol, ntrace_total, selection):
    """Resolve ``extraction_height`` into integer pixel heights for ``traces``.

    ``traces`` are the already-selected traces, ``ntrace_total`` is the
    number of traces before selection and ``selection`` the slice that was
    used to select them.
    """
    ntrace = len(traces)
    xwd = extraction_height

    # np.ndim treats Python scalars, NumPy scalars AND 0-d arrays as scalars;
    # np.isscalar would send a 0-d array down the per-trace array path.
    if np.ndim(xwd) == 0:
        if xwd < 3:
            # Fractional height: scale by the typical distance between traces.
            if ntrace > 1:
                x_mid = ncol // 2
                y_mids = np.sort([np.polyval(t.pos, x_mid) for t in traces])
                spacing = np.median(np.abs(np.diff(y_mids)))
                xwd = int(xwd * spacing)
            else:
                xwd = 10
        return np.full(ntrace, int(xwd))

    # Per-trace heights (taken as pixels).
    xwd_arr = np.asarray(xwd, dtype=int)
    # A full-length array must be restricted to the selected traces, otherwise
    # selected trace i would be paired with the height of unselected trace i.
    if len(xwd_arr) == ntrace_total and ntrace_total != ntrace:
        xwd_arr = xwd_arr[selection]
    return xwd_arr
def _merge_center_rows(y_mid, height):
    """Row slice that places an image of ``height`` rows with its middle row on ``y_mid``."""
    low = y_mid - height // 2
    return slice(low, low + height)


def _merge_trace_wavelengths(wave, column_range, iord):
    """Wavelengths of order ``iord`` restricted to its column range."""
    cr = column_range[iord]
    return wave[iord][cr[0] : cr[1]]


def merge_images(images, wave, column_range, extraction_height):
    """Merge rectified trace images into a single image along wavelength.

    Orders are written side by side. Where two consecutive orders overlap in
    wavelength, both are linearly interpolated onto a common geometrically
    spaced grid (so wavelengths must be positive) and rows covered by both
    are averaged. Every image is placed vertically so that its middle row
    lies on the middle row of the output.

    NOTE(review): the overlap bookkeeping assumes wavelength increases with
    column inside each order and from one order to the next -- confirm for
    instruments where this does not hold.

    Parameters
    ----------
    images : dict
        Rectified 2-D images ``(rows, columns)`` keyed by order position
        ``0 .. len(wave) - 1`` (as returned by ``rectify_image``).
    wave : sequence of arrays
        Per-order wavelength arrays indexed by detector column.
    column_range : array
        Per-order ``(start, end)`` columns; ``wave[i][start:end]`` must
        match the columns of ``images[i]``.
    extraction_height : array
        Per-order extraction heights. Kept for API compatibility; placement
        uses the actual row count of each image, which is always consistent
        with the data being written.

    Returns
    -------
    wavelength : array
        1-D merged wavelength axis.
    combined_img : array
        Merged image of shape ``(max image rows, len(wavelength))``.
    """
    nord = len(wave)
    if nord == 0:
        return np.zeros(0), np.zeros((0, 0))

    # Upper bound on the output width: overlaps only ever shrink it.
    x_total = sum(img.shape[1] for img in images.values())
    y_max = max(img.shape[0] for img in images.values())
    y_mid = y_max // 2
    combined_img = np.zeros((y_max, x_total))
    wavelength = np.zeros(x_total)

    idx = 0  # next free output column
    x_low = 0  # first column of the current order not yet written
    for iord0 in range(nord - 1):
        iord1 = iord0 + 1
        img0, img1 = images[iord0], images[iord1]
        rows0 = _merge_center_rows(y_mid, img0.shape[0])
        rows1 = _merge_center_rows(y_mid, img1.shape[0])
        w0 = _merge_trace_wavelengths(wave, column_range, iord0)
        w1 = _merge_trace_wavelengths(wave, column_range, iord1)

        # Columns of each order that fall inside the other order's range
        i0 = np.ma.where((w0 >= np.ma.min(w1)) & (w0 <= np.ma.max(w1)))[0]
        i1 = np.ma.where((w1 >= np.ma.min(w0)) & (w1 <= np.ma.max(w0)))[0]

        # interp1d needs at least two samples on each side of the overlap
        if i0.size > 1 and i1.size > 1:
            # The non-overlapping part is just the image; clamp so a previous
            # overlap reaching into this one never yields a negative width.
            x_high = max(int(i0.min()), x_low)
            width = x_high - x_low
            combined_img[rows0, idx : idx + width] = img0[:, x_low:x_high]
            wavelength[idx : idx + width] = w0[x_low:x_high]
            idx += width

            # For the overlap region use a common wavelength grid
            n_points = (i0.size + i1.size) // 2
            w_common = np.geomspace(w0[i0][0], w1[i1][-1], num=n_points)
            img0_common = interp1d(
                w0[i0], img0[:, i0], kind="linear", fill_value="extrapolate"
            )(w_common)
            img1_common = interp1d(
                w1[i1], img1[:, i1], kind="linear", fill_value="extrapolate"
            )(w_common)

            # And then simply take the average between the two
            img_common = np.zeros((y_max, n_points))
            counter_common = np.zeros((y_max, n_points), dtype=int)
            img_common[rows0] += img0_common
            counter_common[rows0] += 1
            img_common[rows1] += img1_common
            counter_common[rows1] += 1
            counter_common[counter_common == 0] = 1
            img_common /= counter_common

            combined_img[:, idx : idx + n_points] = img_common
            wavelength[idx : idx + n_points] = w_common
            idx += n_points
            # The next order resumes right after its last overlap column,
            # which is already represented by the end of w_common.
            x_low = int(i1.max()) + 1
        else:
            # No usable overlap: write the rest of this order unchanged.
            x_high = img0.shape[1]
            width = x_high - x_low
            combined_img[rows0, idx : idx + width] = img0[:, x_low:x_high]
            wavelength[idx : idx + width] = w0[x_low:x_high]
            idx += width
            x_low = 0

    # Remainder of the last order
    last = nord - 1
    img_last = images[last]
    w_last = _merge_trace_wavelengths(wave, column_range, last)
    x_high = img_last.shape[1]
    width = x_high - x_low
    rows_last = _merge_center_rows(y_mid, img_last.shape[0])
    combined_img[rows_last, idx : idx + width] = img_last[:, x_low:x_high]
    wavelength[idx : idx + width] = w_last[x_low:x_high]
    idx += width

    return wavelength[:idx], combined_img[:, :idx]