"""
Find clusters of pixels with signal and fit polynomial traces.
Note on terminology:
- "trace": A single polynomial fit to a cluster of pixels (e.g., one fiber)
- "spectral order": A group of traces at similar wavelengths (e.g., all fibers in one echelle order)
The main function `trace()` detects and fits individual traces, returning Trace objects.
Use `group_fibers()` to merge traces into fiber groups according to instrument config.
"""
import logging
import re
from functools import cmp_to_key
from itertools import combinations
import matplotlib.pyplot as plt
import numpy as np
def _natural_sort_key(s):
"""Sort key for natural ordering (e.g., bundle_2 before bundle_10)."""
return [int(c) if c.isdigit() else c.lower() for c in re.split(r"(\d+)", s)]
from astropy.convolution import Gaussian2DKernel, interpolate_replace_nans
from numpy.polynomial.polynomial import Polynomial
from scipy.ndimage import binary_closing, binary_opening, label
from scipy.ndimage.filters import gaussian_filter1d, median_filter, uniform_filter1d
from scipy.signal import find_peaks, peak_widths
from scipy.sparse import diags
from scipy.sparse.linalg import spsolve
from . import util
from .trace_model import Trace as TraceData
logger = logging.getLogger(__name__)
def _find_beam_pairs_dp(y_positions, fibers_per_order):
"""Find optimal pairing of traces into beam pairs using gap analysis.
For dual-beam instruments (fibers_per_order=2), beam pairs have smaller
gaps than inter-order gaps. This function uses dynamic programming to
find the pairing that maximizes the number of paired traces while only
allowing pairs whose gap is below an automatically computed threshold.
Parameters
----------
y_positions : array
Sorted y-positions of traces at detector center.
fibers_per_order : int
Number of fibers (beams) per order (typically 2).
Returns
-------
list of tuples
Each tuple contains `fibers_per_order` trace indices forming one order.
"""
n = len(y_positions)
if n < fibers_per_order:
return []
# Only implemented for fibers_per_order=2
if fibers_per_order != 2:
# Fallback to simple sequential grouping
groups = []
for i in range(0, n - n % fibers_per_order, fibers_per_order):
groups.append(tuple(range(i, i + fibers_per_order)))
return groups
gaps = np.diff(y_positions)
# Find threshold separating beam-pair gaps from inter-order gaps.
# Use Otsu's method: find the threshold that minimizes the weighted
# intra-class variance of the two groups (beam-pair vs inter-order).
sorted_gaps = np.sort(gaps)
best_threshold = np.median(sorted_gaps)
best_variance = np.inf
for i in range(1, len(sorted_gaps)):
if sorted_gaps[i] == sorted_gaps[i - 1]:
continue
candidate = (sorted_gaps[i - 1] + sorted_gaps[i]) / 2
lo_group = sorted_gaps[:i]
hi_group = sorted_gaps[i:]
w0 = len(lo_group) / len(sorted_gaps)
w1 = len(hi_group) / len(sorted_gaps)
weighted_var = w0 * np.var(lo_group) + w1 * np.var(hi_group)
if weighted_var < best_variance:
best_variance = weighted_var
best_threshold = candidate
threshold = best_threshold
n_beam = int(np.sum(gaps <= threshold))
n_inter = int(np.sum(gaps > threshold))
logger.info(
"Beam-pair gap threshold: %.1f px (%d beam gaps, %d inter-order gaps)",
threshold,
n_beam,
n_inter,
)
# DP: find maximum number of paired traces using only small-gap pairs.
# dp[i] = (num_paired, total_gap) for traces 0..i-1
# At each position, either skip a trace or pair it with the previous one.
dp_count = np.zeros(n + 1, dtype=int)
dp_gap = np.zeros(n + 1)
choice = np.zeros(n + 1, dtype=int) # 0=skip, 1=pair
for i in range(2, n + 1):
# Option 1: skip trace i-1
dp_count[i] = dp_count[i - 1]
dp_gap[i] = dp_gap[i - 1]
choice[i] = 0
# Option 2: pair traces i-2 and i-1
if gaps[i - 2] <= threshold:
new_count = dp_count[i - 2] + 2
new_gap = dp_gap[i - 2] + gaps[i - 2]
if new_count > dp_count[i] or (
new_count == dp_count[i] and new_gap < dp_gap[i]
):
dp_count[i] = new_count
dp_gap[i] = new_gap
choice[i] = 1
# Backtrack to find pairs
pairs = []
i = n
while i >= 2:
if choice[i] == 1:
pairs.append((i - 2, i - 1))
i -= 2
else:
i -= 1
pairs.reverse()
logger.info(
"DP pairing: %d pairs from %d traces (%d unpaired)",
len(pairs),
n,
n - 2 * len(pairs),
)
return pairs
def _assign_order_and_fiber_inplace(
traces: list,
order_centers: dict[int, float] | None,
ncol: int,
fibers_per_order: int | None = None,
bundle_centers: dict[int, float] | None = None,
) -> None:
"""Assign m (order number), bundle, and fiber_idx to Trace objects.
Parameters
----------
traces : list[Trace]
Trace objects (modified in place)
order_centers : dict[int, float] | None
Order number -> y-position mapping. If None, m stays None
(unless fibers_per_order is set for auto-pairing or bundle_centers
is provided alone, in which case the sequential-m fallback is
skipped because each trace is one fiber within a bundle, not its
own order).
ncol : int
Number of columns in detector
fibers_per_order : int or None
If set and order_centers is None, group every N consecutive traces
(sorted by y-position) into the same order with sequential fiber_idx.
bundle_centers : dict[int, float] | None
Bundle id -> y-position mapping. When set, each trace is matched
to the nearest bundle center; t.bundle is populated. fiber_idx is
then re-assigned within each (m, bundle).
"""
if not traces:
return
x_center = ncol // 2
y_positions = [t.y_at_x(x_center) for t in traces]
if order_centers is not None:
# Match traces to known order centers
order_nums = np.array(list(order_centers.keys()))
y_centers = np.array([order_centers[m] for m in order_nums])
for i, y in enumerate(y_positions):
distances = np.abs(y - y_centers)
closest_idx = np.argmin(distances)
traces[i].m = int(order_nums[closest_idx])
# Group by m to assign fiber_idx within each order
from collections import defaultdict
traces_by_m = defaultdict(list)
for i, t in enumerate(traces):
traces_by_m[t.m].append((i, y_positions[i]))
# Assign fiber_idx: sort by y within each order, then 1, 2, 3...
for _m, trace_list in traces_by_m.items():
trace_list.sort(key=lambda x: x[1])
for fiber_idx, (trace_idx, _y) in enumerate(trace_list, start=1):
traces[trace_idx].fiber_idx = fiber_idx
elif fibers_per_order is not None and fibers_per_order > 1:
# Auto-pair traces using gap analysis (beam-pair gaps are smaller
# than inter-order gaps). Uses DP to find optimal pairing.
traces.sort(key=lambda t: t.y_at_x(x_center))
y_pos = np.array([t.y_at_x(x_center) for t in traces])
pair_indices = _find_beam_pairs_dp(y_pos, fibers_per_order)
# Assign m and fiber_idx from DP result
order_num = 0
paired_set = set()
for group in pair_indices:
for fiber_idx, trace_idx in enumerate(group, start=1):
traces[trace_idx].m = order_num
traces[trace_idx].fiber_idx = fiber_idx
paired_set.add(trace_idx)
order_num += 1
# Drop unpaired traces
n_dropped = len(traces) - len(paired_set)
if n_dropped > 0:
logger.warning(
"%d traces could not be paired and will be dropped", n_dropped
)
traces[:] = [t for i, t in enumerate(traces) if i in paired_set]
logger.info(
"Auto-paired %d traces into %d orders (fibers_per_order=%d)",
len(traces),
order_num,
fibers_per_order,
)
elif bundle_centers is None:
# No order_centers, no fibers_per_order, no bundle_centers: assign
# sequential m and fiber_idx=1 to all traces (one fiber per order).
traces.sort(key=lambda t: t.y_at_x(x_center))
for i, t in enumerate(traces):
t.m = i
t.fiber_idx = 1
# else: bundle_centers handles assignment below; m stays None.
# Bundle assignment: independent of m, populates t.bundle and
# re-derives fiber_idx within each (m, bundle).
if bundle_centers is not None:
from collections import defaultdict
bundle_ids = np.array(list(bundle_centers.keys()))
bundle_ys = np.array([bundle_centers[b] for b in bundle_ids])
for i, y in enumerate(y_positions):
closest = int(np.argmin(np.abs(y - bundle_ys)))
traces[i].bundle = int(bundle_ids[closest])
traces_by_mb = defaultdict(list)
for i, t in enumerate(traces):
traces_by_mb[(t.m, t.bundle)].append((i, y_positions[i]))
for _key, trace_list in traces_by_mb.items():
trace_list.sort(key=lambda x: x[1])
for fiber_idx, (trace_idx, _y) in enumerate(trace_list, start=1):
traces[trace_idx].fiber_idx = fiber_idx
# Sort traces by (m descending, bundle, fiber_idx)
def sort_key(t):
m_val = t.m if t.m is not None else float("inf")
b_val = t.bundle if t.bundle is not None else 0
return (-m_val, b_val, t.fiber_idx or 0)
traces.sort(key=sort_key)
logger.info("Assigned order/fiber to %d traces", len(traces))
if order_centers is not None:
unique_m = {t.m for t in traces if t.m is not None}
logger.info(" Order numbers (m): %s", sorted(unique_m, reverse=True))
if bundle_centers is not None:
unique_b = {t.bundle for t in traces if t.bundle is not None}
logger.info(
" Bundle ids: %d unique (range %d..%d)",
len(unique_b),
min(unique_b),
max(unique_b),
)
def _compute_heights_inplace(traces: list, ncol: int) -> None:
"""Compute and set extraction heights on Trace objects based on neighbor distances.
For each trace, measures the distance to neighbors at multiple reference
columns (0.1, 0.2, ..., 0.9 of detector width) and uses the maximum.
Parameters
----------
traces : list[Trace]
Trace objects (modified in place to set height attribute)
ncol : int
Number of columns in detector
"""
ntrace = len(traces)
if ntrace == 0:
return
if ntrace == 1:
# Single trace: no neighbors, leave height as None
return
# Reference columns at 0.1, 0.2, ..., 0.9 of detector width
ref_fractions = np.linspace(0.1, 0.9, 9)
ref_cols = (ref_fractions * ncol).astype(int)
for i, t in enumerate(traces):
# Determine which reference columns are within this trace's range
valid_cols = ref_cols[
(ref_cols >= t.column_range[0]) & (ref_cols < t.column_range[1])
]
if len(valid_cols) == 0:
valid_cols = [(t.column_range[0] + t.column_range[1]) // 2]
max_height = 0.0
for x in valid_cols:
y_i = t.y_at_x(x)
if i == 0:
y_next = traces[i + 1].y_at_x(x)
height = abs(y_next - y_i)
elif i == ntrace - 1:
y_prev = traces[i - 1].y_at_x(x)
height = abs(y_i - y_prev)
else:
y_prev = traces[i - 1].y_at_x(x)
y_next = traces[i + 1].y_at_x(x)
height = min(y_i - y_prev, y_next - y_i)
max_height = max(max_height, height)
t.height = max_height
[docs]
def whittaker_smooth(y, lam, axis=0):
"""Whittaker smoother (optimal filter).
Solves: min sum((y - z)^2) + lam * sum((z[i] - z[i-1])^2)
Parameters
----------
y : array
Input data (1D or 2D)
lam : float
Smoothing parameter (higher = smoother)
axis : int
Axis along which to smooth (for 2D arrays)
Returns
-------
z : array
Smoothed data
"""
if y.ndim == 1:
n = len(y)
# Construct tridiagonal matrix: W + lam * D'D
# where D is first-difference matrix
diag_main = np.ones(n) + 2 * lam
diag_main[0] = 1 + lam
diag_main[-1] = 1 + lam
diag_off = -lam * np.ones(n - 1)
A = diags([diag_off, diag_main, diag_off], [-1, 0, 1], format="csc")
return spsolve(A, y)
else:
# Apply along specified axis
return np.apply_along_axis(lambda row: whittaker_smooth(row, lam), axis, y)
[docs]
def fit(x, y, deg, regularization=0):
# order = polyfit1d(y, x, deg, regularization)
if deg == "best":
order = best_fit(x, y)
else:
order = Polynomial.fit(y, x, deg=deg, domain=[]).coef[::-1]
return order
[docs]
def best_fit(x, y):
aic = np.inf
for k in range(5):
coeff_new = fit(x, y, k)
chisq = np.sum((np.polyval(coeff_new, y) - x) ** 2)
aic_new = 2 * k + chisq
if aic_new > aic:
break
else:
coeff = coeff_new
aic = aic_new
return coeff
[docs]
def determine_overlap_rating(xi, yi, xj, yj, mean_cluster_thickness, nrow, ncol, deg=2):
# i and j are the indices of the 2 clusters
i_left, i_right = yi.min(), yi.max()
j_left, j_right = yj.min(), yj.max()
# The number of pixels in the smaller cluster
# this limits the accuracy of the fit
n_min = min(i_right - i_left, j_right - j_left)
# Fit a polynomial to each cluster
order_i = fit(xi, yi, deg)
order_j = fit(xj, yj, deg)
# Get polynomial points inside cluster limits for each cluster and polynomial
y_ii = np.polyval(order_i, np.arange(i_left, i_right))
y_ij = np.polyval(order_i, np.arange(j_left, j_right))
y_jj = np.polyval(order_j, np.arange(j_left, j_right))
y_ji = np.polyval(order_j, np.arange(i_left, i_right))
# difference of polynomials within each cluster limit
diff_i = np.abs(y_ii - y_ji)
diff_j = np.abs(y_ij - y_jj)
ind_i = np.where((diff_i < mean_cluster_thickness) & (y_ji >= 0) & (y_ji < nrow))
ind_j = np.where((diff_j < mean_cluster_thickness) & (y_ij >= 0) & (y_ij < nrow))
# TODO: There should probably be some kind of normaliztion, that scales with the size of the cluster?
# or possibly only use the closest pixels to determine overlap, since the polynomial is badly constrained outside of the bounds.
overlap = min(n_min, len(ind_i[0])) + min(n_min, len(ind_j[0]))
# overlap = overlap / ((i_right - i_left) + (j_right - j_left))
overlap /= 2 * n_min
if i_right < j_left:
overlap *= 1 - (i_right - j_left) / ncol
elif j_right < i_left:
overlap *= 1 - (j_right - i_left) / ncol
overlap_region = [-1, -1]
if len(ind_i[0]) > 0:
overlap_region[0] = np.min(ind_i[0]) + i_left
if len(ind_j[0]) > 0:
overlap_region[1] = np.max(ind_j[0]) + j_left
return overlap, overlap_region
[docs]
def create_merge_array(x, y, mean_cluster_thickness, nrow, ncol, deg, threshold):
n_clusters = list(x.keys())
nmax = len(n_clusters) ** 2
merge = np.zeros((nmax, 5))
for k, (i, j) in enumerate(combinations(n_clusters, 2)):
overlap, region = determine_overlap_rating(
x[i], y[i], x[j], y[j], mean_cluster_thickness, nrow, ncol, deg=deg
)
merge[k] = [i, j, overlap, *region]
merge = merge[merge[:, 2] > threshold]
merge = merge[np.argsort(merge[:, 2])[::-1]]
return merge
[docs]
def update_merge_array(
merge, x, y, j, mean_cluster_thickness, nrow, ncol, deg, threshold
):
j = int(j)
n_clusters = np.array(list(x.keys()))
update = []
for i in n_clusters[n_clusters != j]:
overlap, region = determine_overlap_rating(
x[i], y[i], x[j], y[j], mean_cluster_thickness, nrow, ncol, deg=deg
)
if overlap <= threshold:
# no , or little overlap
continue
update += [[i, j, overlap, *region]]
if len(update) == 0:
return merge
update = np.array(update)
merge = np.concatenate((merge, update))
merge = merge[np.argsort(merge[:, 2])[::-1]]
return merge
[docs]
def calculate_mean_cluster_thickness(x, y):
mean_cluster_thickness = 10 # Default thickness if no clusters found
cluster_thicknesses = []
for cluster in x.keys():
if cluster == 0:
continue # Skip the background cluster if present
# Get all y-coordinates and corresponding x-coordinates for this cluster
y_coords = y[cluster]
x_coords = x[cluster]
# Find unique columns and precompute the x-coordinates for each column
unique_columns = np.unique(y_coords)
column_thicknesses = []
for col in unique_columns:
# Select x-coordinates that correspond to the current column
col_indices = y_coords == col
if np.any(col_indices):
x_in_col = x_coords[col_indices]
thickness = x_in_col.max() - x_in_col.min()
column_thicknesses.append(thickness)
# Average thickness per cluster, if any columns were processed
if column_thicknesses:
cluster_thicknesses.append(np.mean(column_thicknesses))
# Compute the final mean thickness adjusted by the number of clusters
if cluster_thicknesses:
mean_cluster_thickness = (
1.5 * np.mean(cluster_thicknesses) / len(cluster_thicknesses)
)
return mean_cluster_thickness
# origianl version
# def calculate_mean_cluster_thickness(x, y):
# # Calculate mean cluster thickness
# # TODO optimize
# n_clusters = list(x.keys())
# mean_cluster_thickness = 10
# for cluster in n_clusters:
# # individual columns of this cluster
# columns = np.unique(y[cluster])
# delta = 0
# for col in columns:
# # thickness of the cluster in each column
# tmp = x[cluster][y[cluster] == col]
# delta += np.max(tmp) - np.min(tmp)
# mean_cluster_thickness += delta / len(columns)
# mean_cluster_thickness *= 1.5 / len(n_clusters)
# return mean_cluster_thickness
[docs]
def delete(i, x, y, merge):
del x[i], y[i]
merge = merge[(merge[:, 0] != i) & (merge[:, 1] != i)]
return x, y, merge
[docs]
def combine(i, j, x, y, merge, mct, nrow, ncol, deg, threshold):
# Merge pixels
y[j] = np.concatenate((y[j], y[i]))
x[j] = np.concatenate((x[j], x[i]))
# Delete obsolete data
x, y, merge = delete(i, x, y, merge)
merge = merge[(merge[:, 0] != j) & (merge[:, 1] != j)]
# Update merge array
merge = update_merge_array(merge, x, y, j, mct, nrow, ncol, deg, threshold)
return x, y, merge
[docs]
def merge_clusters(
img,
x,
y,
n_clusters,
manual=True,
deg=2,
auto_merge_threshold=0.9,
merge_min_threshold=0.1,
plot_title=None,
):
"""Merge clusters that belong together
Parameters
----------
img : array[nrow, ncol]
the image the order trace is based on
x : dict(int, array(int))
x coordinates of cluster points
y : dict(int, array(int))
y coordinates of cluster points
n_clusters : array(int)
cluster numbers
manual : bool, optional
if True ask before merging clusters (default: True)
deg : int, optional
polynomial degree for fitting (default: 2)
auto_merge_threshold : float, optional
overlap threshold for automatic merging (default: 0.9)
merge_min_threshold : float, optional
minimum overlap to consider merging (default: 0.1)
plot_title : str, optional
title for plots
Returns
-------
x : dict(int: array)
x coordinates of clusters, key=cluster id
y : dict(int: array)
y coordinates of clusters, key=cluster id
n_clusters : array(int)
cluster labels
"""
# Skip all merge computation when merging is disabled
if auto_merge_threshold == 1 and not manual:
n_clusters = list(x.keys())
return x, y, n_clusters
nrow, ncol = img.shape
mct = calculate_mean_cluster_thickness(x, y)
merge = create_merge_array(x, y, mct, nrow, ncol, deg, merge_min_threshold)
if manual:
plt.ion()
plt.figure() # dedicated figure for manual merge mode
k = 0
while k < len(merge):
i, j, overlap, _, _ = merge[k]
i, j = int(i), int(j)
if overlap >= auto_merge_threshold and auto_merge_threshold != 1:
answer = "y"
elif manual:
title = f"Probability: {overlap}"
if plot_title is not None:
title = f"{plot_title}\n{title}"
plot_trace_pair(i, j, x, y, img, deg, title=title)
while True:
if manual:
answer = input("Merge? [y/n]")
if answer in "ynrg":
break
else:
answer = "n"
if answer == "y":
# just merge automatically
logger.info("Merging orders %i and %i", i, j)
x, y, merge = combine(
i, j, x, y, merge, mct, nrow, ncol, deg, merge_min_threshold
)
elif answer == "n":
k += 1
elif answer == "r":
x, y, merge = delete(i, x, y, merge)
elif answer == "g":
x, y, merge = delete(j, x, y, merge)
if manual:
plt.close()
plt.ioff()
n_clusters = list(x.keys())
return x, y, n_clusters
[docs]
def fit_polynomials_to_clusters(x, y, clusters, degree, regularization=0):
"""Fits a polynomial of degree opower to points x, y in cluster clusters
Parameters
----------
x : dict(int: array)
x coordinates seperated by cluster
y : dict(int: array)
y coordinates seperated by cluster
clusters : list(int)
cluster labels, equivalent to x.keys() or y.keys()
degree : int
degree of polynomial fit
Returns
-------
traces : dict(int, array[degree+1])
coefficients of polynomial fit for each cluster
"""
traces = {c: fit(x[c], y[c], degree, regularization) for c in clusters}
return traces
[docs]
def plot_traces(im, traces, title=None):
"""Plot traces and image.
Parameters
----------
im : ndarray
Input image
traces : list[Trace]
Trace objects to plot
title : str, optional
Plot title
"""
plt.figure()
plt.subplot(121)
# Handle non-finite values for plotting
plot_im = np.where(np.isfinite(im), im, np.nan)
valid = np.isfinite(plot_im)
if np.any(valid):
bot, top = np.percentile(plot_im[valid], (1, 99))
if bot >= top:
bot, top = None, None
else:
bot, top = None, None
plt.imshow(plot_im, origin="lower", vmin=bot, vmax=top)
plt.title("Input Image + Trace polynomials")
plt.xlabel("x [pixel]")
plt.ylabel("y [pixel]")
plt.ylim([0, im.shape[0]])
for t in traces:
x = np.arange(*t.column_range, 1)
y = t.y_at_x(x)
plt.plot(x, y)
plt.subplot(122)
plt.imshow(plot_im, origin="lower", vmin=bot, vmax=top)
plt.title("Trace Polynomials")
plt.xlabel("x [pixel]")
plt.ylabel("y [pixel]")
for t in traces:
x = np.arange(*t.column_range, 1)
y = t.y_at_x(x)
plt.plot(x, y)
plt.ylim([0, im.shape[0]])
if title is not None:
plt.suptitle(title)
util.show_or_save("trace_fitted")
[docs]
def plot_trace_pair(i, j, x, y, img, deg, title=""):
"""Plot two trace candidates for merge decision"""
_, ncol = img.shape
trace_i = fit(x[i], y[i], deg)
trace_j = fit(x[j], y[j], deg)
xp = np.arange(ncol)
yi = np.polyval(trace_i, xp)
yj = np.polyval(trace_j, xp)
xmin = min(np.min(x[i]), np.min(x[j])) - 50
xmax = max(np.max(x[i]), np.max(x[j])) + 50
ymin = min(np.min(y[i]), np.min(y[j])) - 50
ymax = max(np.max(y[i]), np.max(y[j])) + 50
yymin = min(max(0, ymin), img.shape[0] - 2)
yymax = min(ymax, img.shape[0] - 1)
xxmin = min(max(0, xmin), img.shape[1] - 2)
xxmax = min(xmax, img.shape[1] - 1)
vmin, vmax = np.percentile(img[yymin:yymax, xxmin:xxmax], (5, 95))
plt.clf()
plt.title(title)
plt.imshow(img, vmin=vmin, vmax=vmax)
plt.plot(xp, yi, "r")
plt.plot(xp, yj, "g")
plt.plot(y[i], x[i], "r.")
plt.plot(y[j], x[j], "g.")
plt.xlim([ymin, ymax])
plt.ylim([xmin, xmax])
util.show_or_save(f"trace_merge_{i}_{j}")
[docs]
def trace(
im,
min_cluster=None,
min_width=None,
filter_x=0,
filter_y=None,
filter_type="boxcar",
noise=0,
noise_relative=0,
degree=4,
border_width=None,
degree_before_merge=2,
regularization=0,
closing_shape=(5, 5),
opening_shape=(2, 2),
plot=False,
plot_title=None,
manual=True,
auto_merge_threshold=0.9,
merge_min_threshold=0.1,
sigma=0,
debug_dir=None,
order_centers: dict[int, float] | None = None,
fibers_per_order: int | None = None,
bundle_centers: dict[int, float] | None = None,
):
"""Identify and trace orders, returning Trace objects.
Parameters
----------
im : array[nrow, ncol]
order definition image
min_cluster : int, optional
minimum cluster size in pixels (default: 500)
filter_x : int, optional
Smoothing width along x-axis/dispersion direction (default: 0, no smoothing).
Useful for noisy data or thin fiber traces.
filter_y : int, optional
Smoothing width along y-axis/cross-dispersion direction (default: auto).
Used to estimate local background. For thin closely-spaced traces, use small values.
filter_type : str, optional
Type of smoothing filter: "boxcar" (default), "gaussian", or "whittaker".
Boxcar is a uniform moving average. Whittaker preserves edges better.
noise : float, optional
Absolute noise threshold added to background (default: 0).
noise_relative : float, optional
Relative noise threshold as fraction of background (default: 0).
If both noise and noise_relative are 0, defaults to 0.001 (0.1%).
degree : int, optional
polynomial degree of the order fit (default: 4)
border_width : int or list of 4 int, optional
Pixels to ignore at image edges for order tracing.
If int, same value applied to all edges.
If list: [top, bottom, left, right] for per-side control.
plot : bool, optional
wether to plot the final order fits (default: False)
manual : bool, optional
wether to manually select clusters to merge (strongly recommended) (default: True)
debug_dir : str, optional
if set, write intermediate images (filtered, background, mask) to this directory
order_centers : dict[int, float], optional
Mapping of order number (m) -> y-position at detector center. If provided,
traces are assigned m values by matching to these centers. Otherwise,
m remains None (to be assigned later from wavecal obase).
fibers_per_order : int, optional
Number of fiber traces per spectral order. When set and order_centers is None,
consecutive traces (sorted by y) are grouped into orders of this size.
Used for instruments like HARPSpol where a Wollaston prism splits each
order into multiple beams.
Returns
-------
list[Trace]
Trace objects with:
- m: assigned from order_centers if provided, else None
- fiber_idx: 1 for single-fiber, or sequential within each order for multi-fiber
- group: None (not yet grouped)
- pos: polynomial coefficients
- column_range: valid column range
- height: computed from neighbor distances (None for single trace)
"""
# Convert to signed integer, to avoid underflow problems
im = np.asanyarray(im)
im = im.astype(int)
if filter_y is None:
col = im[:, im.shape[0] // 2]
col = median_filter(col, 5)
threshold = np.percentile(col, 90)
npeaks = find_peaks(col, height=threshold)[0].size
filter_y = im.shape[0] // (npeaks * 2)
logger.info("Median filter size (y), estimated: %i", filter_y)
elif filter_y <= 0:
raise ValueError(f"Expected filter_y > 0, but got {filter_y}")
if border_width is None:
# find width of orders, based on central column
col = im[:, im.shape[0] // 2]
col = median_filter(col, 5)
idx = np.argmax(col)
width = peak_widths(col, [idx])[0][0]
border_width = int(np.ceil(width))
logger.info("Image border width, estimated: %i", border_width)
# Normalize border_width to [top, bottom, left, right]
if isinstance(border_width, (list, tuple)):
if len(border_width) != 4:
raise ValueError(
f"border_width list must have 4 elements [top, bottom, left, right], "
f"got {len(border_width)}"
)
bw_top, bw_bottom, bw_left, bw_right = [int(b) for b in border_width]
if any(b < 0 for b in (bw_top, bw_bottom, bw_left, bw_right)):
raise ValueError(
f"All border_width values must be >= 0, got {border_width}"
)
elif isinstance(border_width, (int, float, np.integer, np.floating)):
bw = int(border_width)
if bw < 0:
raise ValueError(f"Expected border_width >= 0, but got {bw}")
bw_top = bw_bottom = bw_left = bw_right = bw
else:
raise TypeError(
f"border_width must be int or list of 4 int, got {type(border_width)}"
)
if min_cluster is None:
min_cluster = im.shape[1] // 4
logger.info("Minimum cluster size, estimated: %i", min_cluster)
elif not np.isscalar(min_cluster):
raise TypeError(f"Expected scalar minimum cluster size, but got {min_cluster}")
if min_width is None:
min_width = 0.25
if min_width == 0:
pass
elif isinstance(min_width, (float, np.floating)):
min_width = int(min_width * im.shape[0])
# Validate filter_type
valid_filters = ("boxcar", "gaussian", "whittaker")
if filter_type not in valid_filters:
raise ValueError(
f"filter_type must be one of {valid_filters}, got {filter_type}"
)
# Prepare image for thresholding
# Convert masked values to NaN, interpolate, then back to regular ndarray
if np.ma.is_masked(im):
im_clean = np.ma.filled(im.astype(float), fill_value=np.nan)
kernel = Gaussian2DKernel(x_stddev=1.5, y_stddev=2.5)
im_clean = np.asarray(interpolate_replace_nans(im_clean, kernel))
im_clean = np.nan_to_num(im_clean, nan=0.0)
else:
im_clean = np.asarray(im, dtype=float)
# Select filter function based on filter_type
if filter_type == "boxcar":
def smooth(data, size, axis):
return uniform_filter1d(data, int(size), axis=axis, mode="nearest")
elif filter_type == "gaussian":
def smooth(data, size, axis):
return gaussian_filter1d(data, size, axis=axis)
else: # whittaker
def smooth(data, size, axis):
return whittaker_smooth(data, size, axis=axis)
# Optionally smooth along x (dispersion) to reduce noise
# Applied to both signal and background so we detect y-structure only
if filter_x > 0:
im_clean = smooth(im_clean, filter_x, axis=1)
# Estimate local background by smoothing along y (cross-dispersion)
background = smooth(im_clean, filter_y, axis=0)
# Default to 0.1% relative threshold if neither noise parameter is set
if noise == 0 and noise_relative == 0:
noise_relative = 0.001
logger.info("Using default noise_relative=0.001 (0.1%% of background)")
# Threshold: pixels above local background are signal
# Combines absolute (noise) and relative (noise_relative) thresholds
mask = im_clean > background * (1 + noise_relative) + noise
# Plot cross-section of signal vs background at plot level 2
if plot >= 2: # pragma: no cover
ncol = im_clean.shape[1]
mid = ncol // 2
cols = slice(mid - 25, mid + 25)
signal_profile = np.median(im_clean[:, cols], axis=1)
background_profile = np.median(background[:, cols], axis=1)
threshold_profile = background_profile * (1 + noise_relative) + noise
plt.figure()
plt.plot(signal_profile, label="signal (median of 50 middle cols)")
plt.plot(background_profile, label=f"smoothed ({filter_type}={filter_y})")
plt.plot(
threshold_profile, label=f"threshold (noise={noise}, rel={noise_relative})"
)
plt.xlabel("Row (cross-dispersion)")
plt.ylabel("Counts")
title = "Signal vs background profile"
if plot_title is not None:
title = f"{plot_title}\n{title}"
plt.title(title)
plt.legend()
util.show_or_save("trace_signal_vs_background")
mask_initial = mask.copy()
# remove borders
if bw_top > 0:
mask[:bw_top, :] = False
if bw_bottom > 0:
mask[-bw_bottom:, :] = False
if bw_left > 0:
mask[:, :bw_left] = False
if bw_right > 0:
mask[:, -bw_right:] = False
# remove masked areas with no clusters
mask = np.ma.filled(mask, fill_value=False)
# close gaps inbetween clusters
struct = np.full(closing_shape, 1)
mask = binary_closing(mask, struct, border_value=1)
# remove small lonely clusters
struct = np.full(opening_shape, 1)
# struct = generate_binary_structure(2, 1)
mask = binary_opening(mask, struct)
# Write debug output if requested
if debug_dir is not None:
import os
from astropy.io import fits
os.makedirs(debug_dir, exist_ok=True)
fits.writeto(
os.path.join(debug_dir, "trace_filtered.fits"),
im_clean.astype(np.float32),
overwrite=True,
)
fits.writeto(
os.path.join(debug_dir, "trace_background.fits"),
background.astype(np.float32),
overwrite=True,
)
fits.writeto(
os.path.join(debug_dir, "trace_mask_initial.fits"),
mask_initial.astype(np.uint8),
overwrite=True,
)
fits.writeto(
os.path.join(debug_dir, "trace_mask_final.fits"),
mask.astype(np.uint8),
overwrite=True,
)
logger.info("Wrote debug images to %s", debug_dir)
# label clusters
clusters, _ = label(mask)
n_initial = clusters.max()
logger.info("Found %d clusters initially", n_initial)
# remove small clusters
sizes = np.bincount(clusters.ravel())
mask_sizes = sizes > min_cluster
mask_sizes[0] = True # This is the background, which we don't need to remove
n_too_small = np.sum(~mask_sizes) - 1 # -1 for background
clusters[~mask_sizes[clusters]] = 0
# # Reorganize x, y, clusters into a more convenient "pythonic" format
# # x, y become dictionaries, with an entry for each cluster
# # n is just a list of all cluster labels (ignore cluster == 0)
n = np.unique(clusters)
n = n[n != 0]
x = {i: np.where(clusters == c)[0] for i, c in enumerate(n)}
y = {i: np.where(clusters == c)[1] for i, c in enumerate(n)}
if n_too_small > 0:
logger.info(
"Removed %d clusters smaller than min_cluster=%d, %d remain",
n_too_small,
min_cluster,
len(n),
)
def best_fit_degree(x, y):
L1 = np.sum((np.polyval(np.polyfit(y, x, 1), y) - x) ** 2)
L2 = np.sum((np.polyval(np.polyfit(y, x, 2), y) - x) ** 2)
# aic1 = 2 + 2 * np.log(L1) + 4 / (x.size - 2)
# aic2 = 4 + 2 * np.log(L2) + 12 / (x.size - 3)
if L1 < L2:
return 1
else:
return 2
if sigma > 0:
n_before_sigma = len(x)
cluster_degrees = {i: best_fit_degree(x[i], y[i]) for i in x.keys()}
bias = {i: np.polyfit(y[i], x[i], deg=cluster_degrees[i])[-1] for i in x.keys()}
n = list(x.keys())
yt = np.concatenate([y[i] for i in n])
xt = np.concatenate([x[i] - bias[i] for i in n])
coef = np.polyfit(yt, xt, deg=degree_before_merge)
res = np.polyval(coef, yt)
cutoff = sigma * (res - xt).std()
# DEBUG plot
# uy = np.unique(yt)
# mask = np.abs(res - xt) > cutoff
# plt.plot(yt, xt, ".")
# plt.plot(yt[mask], xt[mask], "r.")
# plt.plot(uy, np.polyval(coef, uy))
# plt.show()
#
m = {
i: np.abs(np.polyval(coef, y[i]) - (x[i] - bias[i])) < cutoff
for i in x.keys()
}
k = max(x.keys()) + 1
for i in range(1, k):
new_img = np.zeros(im.shape, dtype=int)
new_img[x[i][~m[i]], y[i][~m[i]]] = 1
clusters, _ = label(new_img)
x[i] = x[i][m[i]]
y[i] = y[i][m[i]]
if len(x[i]) == 0:
del x[i], y[i]
nnew = np.max(clusters)
if nnew != 0:
xidx, yidx = np.indices(im.shape)
for j in range(1, nnew + 1):
xn = xidx[clusters == j]
yn = yidx[clusters == j]
if xn.size >= min_cluster:
x[k] = xn
y[k] = yn
k += 1
# plt.imshow(clusters, origin="lower")
# plt.show()
n_after_sigma = len(x)
if n_after_sigma != n_before_sigma:
logger.info(
"Sigma clipping: %d -> %d clusters", n_before_sigma, n_after_sigma
)
if plot: # pragma: no cover
plt.figure()
title = "Identified clusters"
if plot_title is not None:
title = f"{plot_title}\n{title}"
plt.title(title)
plt.xlabel("x [pixel]")
plt.ylabel("y [pixel]")
# Sort clusters by mean y-position so we can assign alternating colors
sorted_clusters = sorted(x.keys(), key=lambda i: np.mean(x[i]))
# Use distinct colors that cycle, so adjacent clusters are visually distinct
distinct_colors = [
"#e41a1c",
"#377eb8",
"#4daf4a",
"#984ea3",
"#ff7f00",
"#a65628",
]
from matplotlib.colors import ListedColormap
n_colors = len(distinct_colors)
clusters = np.ma.zeros(im.shape, dtype=int)
for color_idx, i in enumerate(sorted_clusters):
clusters[x[i], y[i]] = (color_idx % n_colors) + 1
clusters[clusters == 0] = np.ma.masked
cmap = ListedColormap(distinct_colors)
plt.imshow(clusters, origin="lower", cmap=cmap, vmin=1, vmax=n_colors)
util.show_or_save("trace_clusters")
# Merge clusters, if there are even any possible mergers left
x, y, n = merge_clusters(
im,
x,
y,
n,
manual=manual,
deg=degree_before_merge,
auto_merge_threshold=auto_merge_threshold,
merge_min_threshold=merge_min_threshold,
plot_title=plot_title,
)
if min_width > 0:
n_before_width = len(x)
sizes = {k: v.max() - v.min() for k, v in y.items()}
mask_sizes = {k: v > min_width for k, v in sizes.items()}
for k, v in mask_sizes.items():
if not v:
del x[k]
del y[k]
n = x.keys()
n_too_narrow = n_before_width - len(x)
if n_too_narrow > 0:
logger.info(
"Removed %d clusters narrower than min_width=%d, %d remain",
n_too_narrow,
min_width,
len(x),
)
logger.info("Fitting polynomials to %d clusters", len(x))
traces_dict = fit_polynomials_to_clusters(x, y, n, degree)
# Sort traces from bottom to top, using relative position
def compare(i, j):
_, xi, i_left, i_right = i
_, xj, j_left, j_right = j
if i_right < j_left or j_right < i_left:
return xi.mean() - xj.mean()
left = max(i_left, j_left)
right = min(i_right, j_right)
return xi[left:right].mean() - xj[left:right].mean()
xp = np.arange(im.shape[1])
keys = [
(c, np.polyval(traces_dict[c], xp), y[c].min(), y[c].max()) for c in x.keys()
]
keys = sorted(keys, key=cmp_to_key(compare))
# Create Trace objects in sorted order
trace_objects = []
for cluster_id, _, _, _ in keys:
pos = traces_dict[cluster_id]
cr = (int(y[cluster_id].min()), int(y[cluster_id].max()) + 1)
trace_objects.append(TraceData(m=None, pos=pos, column_range=cr))
# Compute extraction heights based on trace spacing
_compute_heights_inplace(trace_objects, im.shape[1])
# Assign order numbers, bundles, and fiber indices
_assign_order_and_fiber_inplace(
trace_objects,
order_centers,
im.shape[1],
fibers_per_order=fibers_per_order,
bundle_centers=bundle_centers,
)
if plot: # pragma: no cover
plot_traces(im, trace_objects, title=plot_title)
return trace_objects
[docs]
def select_traces_for_step(
traces: list[TraceData],
fibers_config,
step_name: str,
) -> dict[str, list[TraceData]]:
"""Select which traces to use for a given reduction step.
Looks up fibers_config.use[step_name] to determine selection mode.
Parameters
----------
traces : list[Trace]
All trace objects
fibers_config : FibersConfig or None
Fiber configuration (may be None for single-fiber instruments)
step_name : str
Name of the reduction step (e.g., "science", "curvature")
Returns
-------
selected : dict[str, list[Trace]]
{group_name: [traces]} for each selected group
"""
if not traces:
return {}
# No fiber config means use all traces (single-fiber instrument)
if fibers_config is None:
return {"all": traces}
# No groups/bundles defined means use all traces
if fibers_config.groups is None and fibers_config.bundles is None:
return {"all": traces}
# Determine selection for this step from config
if fibers_config.use is not None:
selection = fibers_config.use.get(
step_name, fibers_config.use.get("default", "groups")
)
else:
selection = "groups"
if selection == "groups":
# Return all traces that have explicit group assignment
grouped = [t for t in traces if t.group is not None]
if grouped:
return {"all": grouped}
# No grouped traces, return all
return {"all": traces}
elif selection == "per_fiber":
# Return traces grouped by fiber_idx for per-fiber processing
result = {}
fiber_indices = {t.fiber_idx for t in traces if t.fiber_idx is not None}
if not fiber_indices:
logger.warning("No fiber_idx set on traces, using all traces")
return {"all": traces}
for idx in sorted(fiber_indices):
idx_traces = [t for t in traces if t.fiber_idx == idx]
idx_traces.sort(key=lambda t: (t.m if t.m is not None else 0))
result[f"fiber_{idx}"] = idx_traces
return result
elif isinstance(selection, list):
# Select specific groups by name (or fiber index if numeric)
result = {}
for name in selection:
# Match by group (compare as string, skip ungrouped traces)
selected = [
t for t in traces if t.group is not None and str(t.group) == name
]
if not selected:
try:
idx = int(name)
selected = [t for t in traces if t.fiber_idx == idx]
except (ValueError, TypeError):
pass
if not selected:
logger.warning("Group '%s' not found in trace data", name)
continue
# Sort by m (order number) for consistent ordering
selected.sort(key=lambda t: (t.m if t.m is not None else 0))
result[name] = selected
if not result:
logger.warning("No valid groups selected, using all traces")
return {"all": traces}
return result
else:
logger.warning("Unknown selection type: %s, using all traces", selection)
return {"all": traces}
[docs]
def group_fibers(
traces: list[TraceData],
fibers_config,
degree: int = 4,
bundle_centers: dict[int, float] | None = None,
) -> list[TraceData]:
"""Merge individual fiber traces into groups according to config.
Takes traces with fiber_idx set (individual fibers) and returns NEW traces
with group set (merged/grouped result). The input traces are not modified.
Parameters
----------
traces : list[Trace]
Individual fiber traces with fiber_idx set and group=None.
Can have m set (from order_centers) or None (to be assigned later).
fibers_config : FibersConfig
Configuration specifying groups or bundles.
degree : int, optional
Polynomial degree for refitted traces (used with "average" merge).
Returns
-------
list[Trace]
New Trace objects with:
- group: set to group name (e.g., "A", "B", "bundle_1")
- fiber_idx: None (merged, no individual identity)
- m: preserved from input traces
- pos: new polynomial from merging
- column_range: intersection of member traces
- height: from config or computed from member traces
Returns empty list if no grouping config is provided.
"""
if not traces:
return []
if fibers_config is None:
return []
if fibers_config.groups is None and fibers_config.bundles is None:
return []
# Group input traces by order number (m)
from collections import defaultdict
traces_by_m = defaultdict(list)
for t in traces:
traces_by_m[t.m].append(t)
# Sort within each order by fiber_idx
for m in traces_by_m:
traces_by_m[m].sort(key=lambda t: t.fiber_idx if t.fiber_idx else 0)
result = []
if fibers_config.groups is not None:
# Named groups with explicit ranges
for group_name, group_cfg in fibers_config.groups.items():
start, end = group_cfg.range
start_idx = start - 1 # Convert to 0-based
end_idx = end - 1 # Half-open, end is exclusive
for m, order_traces in sorted(traces_by_m.items()):
# Select fibers in range for this order
if end_idx > len(order_traces):
logger.warning(
"Group %s range [%d, %d) exceeds fiber count %d in order %s",
group_name,
start,
end,
len(order_traces),
m,
)
end_idx = min(end_idx, len(order_traces))
start_idx = max(start_idx, 0)
if start_idx >= end_idx:
continue
selected = order_traces[start_idx:end_idx]
merged = _merge_trace_objects(selected, group_cfg.merge, degree)
if merged is not None:
# Compute height from group config
height = _compute_group_height_from_traces(selected, group_cfg)
result.append(
TraceData(
m=m,
group=group_name,
fiber_idx=None,
pos=merged.pos,
column_range=merged.column_range,
height=height,
)
)
# Bundles run in addition to (not instead of) named groups: a single
# input fiber can feed both a named-group merge AND a bundle merge,
# producing two distinct output traces with different `group` slots.
if fibers_config.bundles is not None:
bundle_cfg = fibers_config.bundles
bundle_size = bundle_cfg.size
# Group traces by (m, bundle). Each group's fibers are merged
# into one Trace named "bundle_{bundle_id}". `m` and `bundle` are
# independent fields: a multi-order bundled instrument has one
# merged trace per (m, bundle), and the same bundle name repeats
# across orders (uniqueness is on (m, bundle, group)).
traces_by_mb = defaultdict(list)
for t in traces:
traces_by_mb[(t.m, t.bundle)].append(t)
for key in traces_by_mb:
traces_by_mb[key].sort(key=lambda t: t.fiber_idx if t.fiber_idx else 0)
for (m, bundle), members in sorted(
traces_by_mb.items(),
key=lambda kv: (
kv[0][0] if kv[0][0] is not None else float("inf"),
kv[0][1] if kv[0][1] is not None else 0,
),
):
if not members:
continue
if bundle is None:
logger.warning(
"Order %s has %d fibers but no bundle assignment; "
"skipping. Configure bundle_centers_file or set "
"Trace.bundle manually.",
m,
len(members),
)
continue
if len(members) != bundle_size:
logger.warning(
"Order %s bundle %s has %d fibers, expected %d",
m,
bundle,
len(members),
bundle_size,
)
bundle_center_y = (
bundle_centers.get(bundle) if bundle_centers is not None else None
)
merged = _merge_trace_objects(
members, bundle_cfg.merge, degree, bundle_center=bundle_center_y
)
if merged is None:
continue
height = _compute_group_height_from_traces(members, bundle_cfg)
result.append(
TraceData(
m=m,
bundle=bundle,
group=f"bundle_{bundle}",
fiber_idx=None,
pos=merged.pos,
column_range=merged.column_range,
height=height,
)
)
# Sort by (m descending, group)
def sort_key(t):
m = t.m if t.m is not None else float("inf")
return (-m, str(t.group) if t.group else "")
result.sort(key=sort_key)
logger.info("Grouped %d fibers into %d traces", len(traces), len(result))
return result
def _merge_trace_objects(
traces: list[TraceData],
merge_method: str | list[int],
degree: int,
bundle_center: float | None = None,
) -> TraceData | None:
"""Merge multiple traces into one according to merge method.
Parameters
----------
traces : list[Trace]
Traces to merge (must have same m)
merge_method : str or list[int]
"center", "center_weight", "average", or list of 1-based indices.
"center_weight" needs ``bundle_center`` and inverse-distance-weights
every member trace against it, so a missing center fiber doesn't
shift the merged y by a fiber slot.
degree : int
Polynomial degree for refitting
bundle_center : float, optional
Reference y at the detector midpoint for this bundle. Required for
``merge_method == "center_weight"``.
Returns
-------
Trace or None
Merged trace (m and group not set), or None if no valid traces
"""
if not traces:
return None
n = len(traces)
# Find shared column range
col_min = max(t.column_range[0] for t in traces)
col_max = min(t.column_range[1] for t in traces)
if col_min >= col_max:
col_min = min(t.column_range[0] for t in traces)
col_max = max(t.column_range[1] for t in traces)
if merge_method == "center":
idx = n // 2
return TraceData(
m=traces[idx].m,
pos=traces[idx].pos,
column_range=(col_min, col_max),
)
elif merge_method == "center_weight":
if bundle_center is None:
raise ValueError(
"merge='center_weight' requires bundle_centers to be configured"
)
# Weight each detected fiber by 1 / |y - bundle_center| at the
# detector midpoint, then blend their full polynomials with those
# fixed weights. Symmetric flanks cancel; a present center fiber
# dominates; a missing center collapses to the symmetric mean.
x_mid = (col_min + col_max) / 2
eps = 0.1 # px; avoids div-by-zero on a perfect-match fiber
weights = np.array(
[1.0 / max(abs(t.y_at_x(x_mid) - bundle_center), eps) for t in traces]
)
weights /= weights.sum()
x_eval = np.arange(col_min, col_max)
y_values = np.array([t.y_at_x(x_eval) for t in traces])
y_mean = (weights[:, None] * y_values).sum(axis=0)
fit = Polynomial.fit(x_eval, y_mean, deg=degree, domain=[])
coeffs = fit.coef[::-1]
return TraceData(
m=traces[0].m,
pos=coeffs,
column_range=(col_min, col_max),
)
elif merge_method == "average":
if col_min >= col_max:
idx = n // 2
return TraceData(
m=traces[idx].m,
pos=traces[idx].pos,
column_range=traces[idx].column_range,
)
x_eval = np.arange(col_min, col_max)
y_values = np.array([t.y_at_x(x_eval) for t in traces])
y_mean = np.mean(y_values, axis=0)
fit = Polynomial.fit(x_eval, y_mean, deg=degree, domain=[])
coeffs = fit.coef[::-1]
return TraceData(
m=traces[0].m,
pos=coeffs,
column_range=(col_min, col_max),
)
elif isinstance(merge_method, list):
indices = [i - 1 for i in merge_method]
valid = [i for i in indices if 0 <= i < n]
if not valid:
logger.warning("No valid indices in merge method %s", merge_method)
return None
if len(valid) == 1:
idx = valid[0]
return TraceData(
m=traces[idx].m,
pos=traces[idx].pos,
column_range=(col_min, col_max),
)
# Multiple indices: average them
x_eval = np.arange(col_min, col_max)
y_values = np.array([traces[i].y_at_x(x_eval) for i in valid])
y_mean = np.mean(y_values, axis=0)
fit = Polynomial.fit(x_eval, y_mean, deg=degree, domain=[])
coeffs = fit.coef[::-1]
return TraceData(
m=traces[0].m,
pos=coeffs,
column_range=(col_min, col_max),
)
else:
raise ValueError(f"Unknown merge method: {merge_method}")
def _compute_group_height_from_traces(
traces: list[TraceData],
config,
) -> float | None:
"""Compute extraction height for a group of traces.
Parameters
----------
traces : list[Trace]
Traces in this group
config : FiberGroupConfig or FiberBundleConfig
Config with height specification
Returns
-------
float or None
Height in pixels, or None to use default
"""
if config.height is None:
return None
if config.height == "derived":
n = len(traces)
if n < 2:
return None
# Get y-positions at column center
col_min = max(t.column_range[0] for t in traces)
col_max = min(t.column_range[1] for t in traces)
x_center = (col_min + col_max) // 2
y_positions = sorted([t.y_at_x(x_center) for t in traces])
spacings = np.diff(y_positions)
fiber_diameter = np.median(spacings) if len(spacings) > 0 else 0
span = y_positions[-1] - y_positions[0]
return span + fiber_diameter
return float(config.height)
def _verify_trace_ordering(traces: list) -> None:
"""Verify that traces are ordered correctly: m decreases as y increases.
For echelle spectrographs, higher spectral orders (larger m) have shorter
wavelengths and appear lower on the detector (smaller y). So as we go up
in y, m should decrease.
Parameters
----------
traces : list[Trace]
Trace objects to verify
Raises
------
ValueError
If traces are not ordered correctly
"""
if len(traces) < 2:
return
# Get y-position at detector center for each trace
# Use the middle of the common column range
x_mid = np.median([np.mean(t.column_range) for t in traces])
y_positions = [t.y_at_x(x_mid) for t in traces]
m_values = [t.m for t in traces]
# Sort by y to check ordering
sorted_pairs = sorted(zip(y_positions, m_values, strict=False))
# Check that m decreases (or stays same for multi-fiber) as y increases
prev_m = sorted_pairs[0][1]
for y, m in sorted_pairs[1:]:
if m is not None and prev_m is not None and m > prev_m:
logger.warning(
"Trace ordering may be incorrect: m=%d at y=%.1f is higher than m=%d below it. "
"Expected m to decrease as y increases.",
m,
y,
prev_m,
)
break
if m is not None:
prev_m = m