import numpy as np
from scipy.interpolate import interp1d
from tqdm import tqdm
from . import util
from .extract import correct_for_curvature
[docs]
def rectify_image(img, traces, extraction_height, trace_range):
    """Cut every selected trace out of ``img`` as a straightened strip.

    For each trace, a band of rows centred on the trace polynomial is
    selected with ``util.make_index``. This removes the trace shape, so the
    trace centre stays within one pixel of the strip's middle row. When the
    trace carries slit-curvature coefficients, the strip is additionally
    passed through ``correct_for_curvature``.

    Parameters
    ----------
    img : array
        Input 2D image (unpacked as ``nrow, ncol``).
    traces : list[Trace]
        Trace objects providing ``pos`` (polynomial coefficients of the
        centre row as a function of column), ``column_range`` (start, end)
        and ``slit`` (curvature coefficients, or None).
    extraction_height : float or array
        A scalar below 3 is a fraction of the median spacing between trace
        centres at the central column; 10 pixels is used instead when fewer
        than two traces are selected. Any other scalar is a height in
        pixels. A sequence gives one pixel height per selected trace.
    trace_range : tuple
        (start, end) slice bounds applied to ``traces``.

    Returns
    -------
    images : dict
        Rectified strips keyed by 0-based position within the selected traces.
    column_range : array
        ``column_range`` of each selected trace.
    extraction_height : array
        Integer extraction height in pixels for each selected trace.
    """
    _, ncol = img.shape
    columns = np.arange(ncol)

    selected = traces[trace_range[0] : trace_range[1]]
    n_selected = len(selected)
    column_range = np.array([trace.column_range for trace in selected])

    # Convert a fractional height into pixels using the typical trace separation.
    height = extraction_height
    if np.isscalar(height) and height < 3:
        centre_col = ncol // 2
        centres = np.array([np.polyval(trace.pos, centre_col) for trace in selected])
        if len(centres) <= 1:
            # No separation can be measured from a single trace: fixed fallback.
            height = 10
        else:
            gaps = np.abs(np.diff(np.sort(centres)))
            height = int(height * np.median(gaps))

    if np.isscalar(height):
        heights = np.full(n_selected, int(height))
    else:
        heights = np.asarray(height, dtype=int)

    images = {}
    for k, trace in enumerate(tqdm(selected, desc="Trace")):
        left, right = trace.column_range
        height_k = heights[k]

        # Integer centre row at every column (astype truncates, it does not round).
        centre_rows = np.polyval(trace.pos, columns).astype(int)
        bottom = centre_rows - height_k // 2
        top = bottom + height_k - 1
        strip = img[util.make_index(bottom, top, left, right)]

        # slit[1, :] holds the linear-term coefficients and slit[2, :] the quadratic ones.
        slit = trace.slit
        if slit is not None and slit.shape[0] > 2:
            span = np.arange(left, right)
            linear = np.polyval(slit[1, :], span)
            quadratic = np.polyval(slit[2, :], span)
            strip = correct_for_curvature(strip, linear, quadratic, height_k)

        images[k] = strip
    return images, column_range, heights
[docs]
def _centered_rows(height, y_max):
    """Return (low, high) row bounds that place ``height`` rows about ``y_max // 2``."""
    low = y_max // 2 - height // 2
    return low, low + height


def merge_images(images, wave, column_range, extraction_height):
    """Stitch rectified trace images into one image along wavelength.

    Consecutive orders ``i`` and ``i + 1`` are compared. Where their
    wavelength ranges overlap, both images are linearly interpolated onto a
    shared geometric wavelength grid and averaged (rows covered by only one
    image keep that image's value). Non-overlapping columns are copied as-is.
    Each image is centred vertically in the output.

    NOTE(review): this assumes wavelength increases with column and with
    order index, so the overlap sits at the end of order ``i`` and the start
    of order ``i + 1``. Confirm against callers.

    Parameters
    ----------
    images : dict[int, array]
        2D rectified images (rows, columns) keyed by order index ``0..len(wave)-1``.
    wave : sequence of array
        Per-order wavelength arrays; sliced with ``column_range`` to match
        the image columns.
    column_range : array
        (start, end) column bounds of each order.
    extraction_height : array
        Kept for API compatibility. The actual row count of each image
        determines its vertical placement; for images produced by
        ``rectify_image`` this presumably equals the extraction height.

    Returns
    -------
    wavelength : array
        1D wavelength of every output column.
    combined_img : array
        2D merged image of shape (max image height, number of output columns).
    """
    nord = len(wave)
    if nord == 0:
        return np.zeros(0), np.zeros((0, 0))

    x_total = sum(img.shape[1] for img in images.values())
    # A generator avoids the TypeError that max(*[h]) raises for a single image.
    y_max = max(img.shape[0] for img in images.values())
    combined_img = np.zeros((y_max, x_total))
    wavelength = np.zeros(x_total)

    idx = 0  # next free output column
    x_low = 0  # first column of the current order not yet written to the output
    for iord0 in range(nord - 1):
        iord1 = iord0 + 1
        img0 = images[iord0]
        img1 = images[iord1]
        y0_low, y0_high = _centered_rows(img0.shape[0], y_max)
        y1_low, y1_high = _centered_rows(img1.shape[0], y_max)

        # Find the columns of each order that lie inside the other order's range.
        cr0, cr1 = column_range[iord0], column_range[iord1]
        w0 = wave[iord0][cr0[0] : cr0[1]]
        w1 = wave[iord1][cr1[0] : cr1[1]]
        i0 = np.ma.where((w0 >= np.ma.min(w1)) & (w0 <= np.ma.max(w1)))
        i1 = np.ma.where((w1 >= np.ma.min(w0)) & (w1 <= np.ma.max(w0)))

        if i0[0].size > 0 and i1[0].size > 0:
            # Copy the part of order 0 that precedes the overlap unchanged.
            x_high = max(int(i0[0].min()), x_low)
            ncol = x_high - x_low
            combined_img[y0_low:y0_high, idx : idx + ncol] = img0[:, x_low:x_high]
            wavelength[idx : idx + ncol] = w0[x_low:x_high]
            idx += ncol

            # Resample both orders onto a common wavelength grid in the overlap.
            n_points = (len(i0[0]) + len(i1[0])) // 2
            w_common = np.geomspace(w0[i0][0], w1[i1][-1], num=n_points)
            img0_common = interp1d(
                w0[i0], img0[:, i0[0]], kind="linear", fill_value="extrapolate"
            )(w_common)
            img1_common = interp1d(
                w1[i1], img1[:, i1[0]], kind="linear", fill_value="extrapolate"
            )(w_common)

            # Average where both images contribute; rows covered once keep their value.
            counter_common = np.zeros((y_max, n_points), dtype=int)
            img_common = np.zeros((y_max, n_points))
            img_common[y0_low:y0_high] += img0_common
            counter_common[y0_low:y0_high] += 1
            img_common[y1_low:y1_high] += img1_common
            counter_common[y1_low:y1_high] += 1
            counter_common[counter_common == 0] = 1
            img_common /= counter_common

            combined_img[:, idx : idx + n_points] = img_common
            wavelength[idx : idx + n_points] = w_common
            idx += n_points
            # The next order resumes just after its overlap (exclusive, like x_high).
            x_low = int(i1[0].max()) + 1
        else:
            # No overlap: copy the rest of order 0 that has not been written yet.
            ncol = img0.shape[1] - x_low
            combined_img[y0_low:y0_high, idx : idx + ncol] = img0[:, x_low:]
            wavelength[idx : idx + ncol] = w0[x_low:]
            idx += ncol
            x_low = 0

    # The last order has no successor: copy whatever remains of it.
    last = nord - 1
    img_last = images[last]
    y_low, y_high = _centered_rows(img_last.shape[0], y_max)
    cr_last = column_range[last]
    w_last = wave[last][cr_last[0] : cr_last[1]]
    ncol = img_last.shape[1] - x_low
    combined_img[y_low:y_high, idx : idx + ncol] = img_last[:, x_low:]
    wavelength[idx : idx + ncol] = w_last[x_low:]
    idx += ncol

    combined_img = combined_img[:, :idx]
    wavelength = wavelength[:idx]
    return wavelength, combined_img