Source code for deep_transit.dt_lightcurve

import io
import math
import warnings
import itertools

import numpy as np
from PIL import Image
from tqdm import tqdm
import argparse

import lightkurve as lk
import matplotlib.pyplot as plt
from astropy.stats import sigma_clip
from wotan import flatten
from .common_utils import warning_on_one_line

# Replace the hook the ``warnings`` module uses to render messages with the
# project's own formatter imported from ``.common_utils``. This is a
# process-wide side effect of importing this module.
# NOTE(review): presumably renders each warning on a single line (judging by
# the helper's name) -- confirm in common_utils.
warnings.formatwarning = warning_on_one_line  # Raise my own warns
# Reset matplotlib's rc parameters to their defaults at import time, which
# also affects any other code sharing this interpreter.
plt.rcdefaults()


def detrend_light_curve(lc_object, window_length=0.5, edge_cutoff=0.5, break_tolerance=0.5, cval=5.0,
                        sigma_upper=3, sigma_lower=20):
    """
    Flatten a light curve ahead of transit searching.

    The trend is estimated with wotan's biweight filter
    (https://github.com/hippke/wotan), the flux is divided by that trend, and
    the normalized flux is then sigma clipped.

    Parameters
    ----------
    lc_object : `~lightkurve.LightCurve` instance
        Input light curve object
    window_length : float
        The length of the filter window in units of ``time``, default is 0.5
    edge_cutoff : float
        Length (in units of time) to be cut off each edge, default is 0.5
    break_tolerance : float
        Split into segments at breaks longer than that, default is 0.5
    cval : float
        Tuning parameter for the robust estimators, default is 5.0
    sigma_upper : float
        Upper limit of standard deviations for sigma clipping, default is 3
    sigma_lower : float
        Lower limit of standard deviations for sigma clipping, default is 20

    Returns
    -------
    flatten_lc : `~lightkurve.LightCurve` instance
        New light curve holding the original times, the detrended and clipped
        flux, and the ``meta`` of the input. Only ``time``, ``flux`` and
        ``meta`` are passed on to the new object.
    """
    times = lc_object.time.value
    raw_flux = lc_object.flux.value

    biweight_options = dict(
        method='biweight',
        window_length=window_length,
        edge_cutoff=edge_cutoff,
        break_tolerance=break_tolerance,
        return_trend=True,  # Return trend and flattened light curve
        cval=cval,
    )
    # Only the trend is kept; the normalization is redone below so that it
    # happens under the same warning filter as the clipping.
    trend_flux = flatten(times, raw_flux, **biweight_options)[1]

    clip_options = dict(
        sigma_upper=sigma_upper,
        sigma_lower=sigma_lower,
        cenfunc=np.nanmedian,
        stdfunc=np.nanstd,
        masked=False,
        axis=0,
    )
    with warnings.catch_warnings():
        # Silence the warnings raised while dividing and clipping.
        warnings.simplefilter("ignore")
        flatten_flux = sigma_clip(raw_flux / trend_flux, **clip_options)

    return lk.LightCurve(time=times, flux=flatten_flux, meta=lc_object.meta)
def smooth_light_curve(lc_object, N_points): """ Smooth light curve with moving average Parameters ---------- lc_object : `~lightkurve.LightCurve` instance N_points : int Window size of moving average Returns ------- smoothed_lc : `~lightkurve.LightCurve` instance """ df = lc_object.to_pandas() t = df.index.values y = df.flux.rolling(N_points, min_periods=N_points // 2).mean().values return lk.LightCurve(time=t, flux=y) def _light_curve_to_image_array(lc_object, flux_range): """ Convert light curve slice to image array Parameters ---------- lc_object : `~lightkurve.LightCurve` instance flux_range : tuple Flux range in 30-day window in format: (flux_min, flux_max) Returns ------- img_arr : np.ndarray Numpy image array """ exp_time = np.nanmin(np.diff(lc_object.time.value)) with plt.rc_context({'backend': 'agg'}): io_buf = io.BytesIO() io_buf.seek(0) plt.ioff() fig, ax = plt.subplots(1, figsize=(4.16, 4.16), dpi=100, frameon=False) fig.subplots_adjust(left=0, right=1, bottom=0, top=1) ax.set_facecolor('white') ax.axis('off') ax.margins(0, 0) if exp_time > 0.00417: ax.plot(lc_object.time.value, lc_object.flux.value, ls='-', marker='o', lw=72 / fig.dpi / 2, ms=72 / fig.dpi * 2, color='black') else: ax.plot(lc_object.time.value, lc_object.flux.value, '.', ms=72 / fig.dpi, color='black', mew=1.0) smoothed_lc = smooth_light_curve(lc_object.remove_nans(), int(0.0204 / np.nanmin(np.diff(lc_object.time.value)))) ax.plot(smoothed_lc.time.value, smoothed_lc.flux.value, 'grey', lw=1) ax.set_ylim([flux_range[0], flux_range[1]]) fig.savefig(io_buf, format='raw', dpi=100, pad_inches=0) img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8), newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), 4)) io_buf.close() plt.close() return img_arr def _bounding_box_to_time_flux(lc_object, bboxes, flux_range): """ Convert image pixel value to time and flux Parameters ---------- lc_object : `~lightkurve.LightCurve` instance bboxes : list List of bounding boxes. 
flux_range : tuple Flux range in 30-day window in format: (flux_min, flux_max) Returns ------- transit_masks : list List of bounding box position, in format: [[confidence, x_time, y_flux, w_time, h_flux], ...] """ lc_object = lc_object.remove_nans() time, flux = lc_object.time.value, lc_object.flux.value t_min, t_max, f_min, f_max = np.min(time), np.max(time), flux_range[0], flux_range[1] transit_masks = [] for bbox in bboxes: confidence, x, y, w, h = bbox[:] x_time = x * (t_max - t_min) + t_min # x_time is the middle time of a box y_flux = f_max - y * (f_max - f_min) # y_flux is the middle flux of a box w_time = w * (t_max - t_min) h_flux = h * (f_max - f_min) transit_masks.append([confidence, x_time, y_flux, w_time, h_flux]) return transit_masks def _split_light_curve(lc_object, split_time=10, back_step=3): """ Split a light curve to some slices with indicated time interval and step Parameters ---------- lc_object : `~lightkurve.LightCurve` instance Input light curve object split_time : float The length of each light curve slice in units of ``time`` back_step : float The backward step of each light curve slice in units of ``time`` Yields ------- time_start : float The begin time of each light curve slice. time_stop : float The end time of each light curve slice. """ if (lc_object.time.value[-1] - lc_object.time.value[0]) < split_time: yield lc_object.time.value[0], lc_object.time.value[-1] else: time_start = lc_object.time.value[0] while time_start < lc_object.time.value[-1]: time_stop = time_start + split_time if time_stop > lc_object.time.value[-1]: yield lc_object.time.value[-1] - split_time, lc_object.time.value[-1] break else: yield time_start, time_stop time_start = time_stop - back_step
class DeepTransit:
    """
    The core class of transit detection.

    The (optionally detrended) light curve is stored in ``self.lc`` and is
    the input of `transit_detection`.
    """

    def __init__(self, lc_object=None, time=None, flux=None, flux_err=None, is_flat=False,
                 lk_kwargs=None, flatten_kwargs=None):
        """
        Initial function for receiving an light curve object or a time series.

        Parameters
        ----------
        lc_object : `~lightkurve.LightCurve` instance
            Input light curve object
        time : `~astropy.time.Time` or iterable
            Time values. They can either be given directly as a
            `~astropy.time.Time` array or as any iterable that initializes the
            `~astropy.time.Time` class. When omitted, ``np.arange(len(flux))``
            is used. Ignored when ``lc_object`` is given.
        flux : `~astropy.units.Quantity` or iterable
            Flux values for every time point. Required when ``lc_object`` is
            not given.
        flux_err : `~astropy.units.Quantity` or iterable
            Uncertainty on each flux data point.
        is_flat : bool
            True when receiving a flattened light curve, False will use the
            built-in flatten method.
        lk_kwargs : dict
            Keyword arguments of `~lightkurve.LightCurve`.
        flatten_kwargs : dict
            Keyword arguments of `detrend_light_curve` method.

        Raises
        ------
        ValueError
            If neither ``lc_object`` nor ``flux`` is given.
        TypeError
            If ``lc_object`` is not a `~lightkurve.LightCurve`.
        """
        lk_kwargs = {} if lk_kwargs is None else lk_kwargs
        flatten_kwargs = {} if flatten_kwargs is None else flatten_kwargs

        if lc_object is None:
            if flux is None:
                raise ValueError("Either 'lc_object' or 'flux' must be provided")
            if time is None:
                # We are tolerant of missing time: fall back to sample indices
                time = np.arange(len(flux))
            lc_object = lk.LightCurve(time=time, flux=flux, flux_err=flux_err, **lk_kwargs).remove_nans()
        elif isinstance(lc_object, lk.LightCurve):
            lc_object = lc_object.remove_nans()
        else:
            raise TypeError("'lc_object' should be a `lightkurve.LightCurve` object")

        lc_object.sort('time')
        # Plain truthiness so that numpy booleans are honoured as well.
        self.lc = lc_object if is_flat else detrend_light_curve(lc_object, **flatten_kwargs)

    def _splited_lc_generator(self, backend):
        """
        Cut ``self.lc`` into windows and render each one as a model input.

        The light curve is first divided at gaps longer than 10 time units.
        Each segment is split into windows of 30 (overlap 5) that fix the
        flux range, and those into sub-windows of 10 (overlap 3) that are
        rendered. Windows with too few samples are skipped.

        Parameters
        ----------
        backend : object
            Backend whose ``trans`` method converts the grayscale image array
            into the model input.

        Yields
        ------
        tuple
            ``(light_curve_slice, image, flux_min, flux_max)``
        """
        times = self.lc.time.value
        if len(times) == 0:
            return

        # Index i marks a gap between sample i and i + 1; the last sample
        # closes the final segment.
        segment_ends = np.append((np.diff(times) > 10).nonzero()[0], len(times) - 1)
        segment_start = 0
        for segment_end in segment_ends:
            # ``segment_end`` is inclusive, hence the + 1 in the slice.
            lc_block = self.lc[segment_start:segment_end + 1]
            segment_start = segment_end + 1

            for start_time, stop_time in _split_light_curve(lc_block.remove_nans(), split_time=30, back_step=5):
                mask = (lc_block.time.value >= start_time) & (lc_block.time.value <= stop_time)
                selected_lc = lc_block[mask]
                if len(selected_lc) < 2:
                    # No cadence can be measured from fewer than two samples.
                    continue
                exp_time = np.nanmin(np.diff(selected_lc.time.value))
                # Require at least 5 time units worth of samples.
                if len(selected_lc) < 5 / exp_time:
                    continue

                # Shared y-range for every sub-window: the maximum on top and
                # 2% of the flux span of extra room below the minimum.
                flux_min = np.nanmin(selected_lc.flux.value) * 1.02 - 0.02 * np.nanmax(selected_lc.flux.value)
                flux_max = np.nanmax(selected_lc.flux.value)

                for t0, t1 in _split_light_curve(selected_lc.remove_nans(), split_time=10, back_step=3):
                    mask = (selected_lc.time.value >= t0) & (selected_lc.time.value <= t1)
                    splited_flatten_lc = selected_lc[mask]
                    if len(splited_flatten_lc) < 2:
                        continue
                    exp_time = np.nanmin(np.diff(splited_flatten_lc.time.value))
                    # Require at least 1 time unit worth of samples.
                    if len(splited_flatten_lc) < 1 / exp_time:
                        continue
                    img_arr = _light_curve_to_image_array(splited_flatten_lc, (flux_min, flux_max))
                    image = np.array(Image.fromarray(img_arr).convert("L"))
                    image = backend.trans(image)
                    yield splited_flatten_lc, image, flux_min, flux_max

    def _data_loader(self, batch_size=1, backend=None):
        """
        Group the output of `_splited_lc_generator` into batches.

        Parameters
        ----------
        batch_size : int
            Maximum number of samples per batch.
        backend : object
            Passed through to `_splited_lc_generator`.

        Yields
        ------
        chunk : np.ndarray
            Object array of shape (n, 4) with ``n <= batch_size``; the columns
            are light curve slice, image, flux_min and flux_max.
        """
        samples = iter(self._splited_lc_generator(backend))
        while True:
            batch = list(itertools.islice(samples, batch_size))
            if not batch:
                return
            # Fill an object array cell by cell. ``np.array(batch)`` would
            # try to broadcast the nested light curves and images, which
            # raises on ragged input in current NumPy releases.
            chunk = np.empty((len(batch), 4), dtype=object)
            for row, sample in enumerate(batch):
                for col, item in enumerate(sample):
                    chunk[row, col] = item
            yield chunk

    def transit_detection(self, local_model_path, batch_size=2, confidence_threshold=0.6, nms_iou_threshold=0.1,
                          device_str='auto', backend='pytorch'):
        """
        Searching transit signals from a given light curve.

        Parameters
        ----------
        local_model_path : str
            The path of the model file.
        batch_size : int
            Batch size for increasing detection speed, especially useful for
            GPU. Default value is 2, if using GPU, it can be higher depending
            on the limitation of the GPU memory.
        confidence_threshold : float
            Confidence threshold for transit detection. If None, the value
            will be obtained from config. Default value is 0.6.
        nms_iou_threshold : float
            IOU threshold for NMS algorithm. If None, the value will be
            obtained from config. Default value is 0.1.
        device_str : str
            Device name. If "cuda", it will use GPU. Default is "auto".
        backend : str
            Backend of the model. You can choose between "pytorch" or
            "megengine". Default is "pytorch".

        Returns
        -------
        final_bboxes : np.ndarray
            An (N, 5) shape numpy.ndarray of bounding boxes.

        Raises
        ------
        ValueError
            If ``backend`` is neither "pytorch" nor "megengine".
        """
        # Backends are imported lazily so only the selected framework has to
        # be installed.
        if backend == 'pytorch':
            from .backend import PytorchBackend
            backend = PytorchBackend(device_str)
        elif backend == 'megengine':
            from .mge.backend import MegengineBackend
            backend = MegengineBackend()
        else:
            raise ValueError(f"Unknown backend {backend!r}, choose between 'pytorch' and 'megengine'")
        backend.load_model(local_model_path)

        # Upper estimate of the batch count for the progress bar. It falls
        # back to an open-ended bar when no cadence can be measured.
        n_points = len(self.lc)
        exp_time = np.nanmedian(np.diff(self.lc.time.value)) if n_points > 1 else np.nan
        if np.isfinite(exp_time):
            rough_length = max(1, math.ceil(math.ceil((n_points * exp_time - 30) / 25 + 1) * 4 / batch_size))
        else:
            rough_length = None
        warnings.warn('The total number of progress bar is the upper limit.')

        real_unit_bboxes = []
        for data in tqdm(self._data_loader(batch_size=batch_size, backend=backend), total=rough_length):
            lc_data = data[:, 0]
            flux_min, flux_max = data[:, 2], data[:, 3]
            predicted_bboxes = backend.inference(
                data[:, 1],
                nms_iou_threshold=nms_iou_threshold,
                confidence_threshold=confidence_threshold)
            for index, bboxes in enumerate(predicted_bboxes):
                # Map image-fraction boxes back to time and flux.
                real_unit_bboxes.extend(
                    _bounding_box_to_time_flux(lc_data[index], bboxes, (flux_min[index], flux_max[index])))

        # Overlapping windows can detect the same event more than once, so
        # NMS is run again over the merged list.
        final_bboxes = np.array(backend.nms(real_unit_bboxes,
                                            nms_iou_threshold=nms_iou_threshold,
                                            confidence_threshold=confidence_threshold))
        return final_bboxes
def plot_lc_with_bboxes(lc_object, bboxes, ax=None, **kwargs):
    """
    Draw a light curve and overlay the detected bounding boxes.

    Every box is outlined in lime and labelled with its confidence (two
    decimals) at its top-left corner.

    Parameters
    ----------
    lc_object : `~lightkurve.LightCurve` instance
    bboxes : list or np.ndarray
        Bounding boxes in shape (N, 5), each row being
        [confidence, x_center, y_center, width, height]
    ax : `~matplotlib.pyplot.axis` instance
        Axis to plot to. If None, create a new one.
    kwargs : dict
        Additional arguments to be passed to `matplotlib.pyplot.plot`

    Returns
    -------
    ax : `~matplotlib.pyplot.axis` instance
        The matplotlib axes object.
    """
    from matplotlib.collections import PatchCollection
    from matplotlib.patches import Rectangle

    with plt.style.context('grayscale'):
        if ax is None:
            _, ax = plt.subplots(1, figsize=(12, 7), constrained_layout=False)
        ax.plot(lc_object.time.value, lc_object.flux.value, **kwargs)

        label_style = dict(alpha=0.5, color='blue')
        outlines = []
        for box in bboxes:
            confidence = box[0]
            x_center, y_center = box[1], box[2]
            width, height = box[3], box[4]
            left = x_center - width / 2

            # Rectangle is anchored at its lower-left corner.
            outlines.append(Rectangle((left, y_center - height / 2), width, height, fill=False, color='lime'))
            # Confidence label hangs from the upper-left corner of the box.
            ax.text(
                left,
                y_center + height / 2,
                s=f"{confidence:.2f}",
                color="white",
                verticalalignment="top",
                bbox=label_style,
                clip_on=True,
            )

        ax.add_collection(PatchCollection(outlines, facecolor='none', edgecolor='lime', lw=1, zorder=3))
        ax.set_xlabel('Time (day)')
        ax.set_ylabel('Normalized Flux')
    return ax
def select_lc_from_bboxes(lc_object, bboxes, fill=1):
    """
    Select the part of a light curve that lies inside the bounding boxes.

    A sample is selected when both its time and its flux fall inside at least
    one box (edges included).

    Parameters
    ----------
    lc_object : `~lightkurve.LightCurve` instance
    bboxes : list
        Bounding boxes, each in format:
        [confidence, x_time, y_flux, w_time, h_flux]. May be empty, in which
        case no sample is selected.
    fill : float
        If None, return light curve in the bounding boxes. Otherwise the outer
        region will be filled with a given value. Default value is 1.

    Returns
    -------
    selected_lc : `~lightkurve.LightCurve` instance
        With ``fill=None`` only the selected samples; otherwise all samples,
        sorted by time, with the flux outside the boxes replaced by ``fill``.
    """
    time = lc_object.time.value
    flux = lc_object.flux.value

    # Start from an explicit all-False mask. A scalar ``False`` would turn
    # into the integer index -1 under ``~`` when ``bboxes`` is empty.
    range_logic = np.zeros(np.shape(time), dtype=bool)
    for bbox in bboxes:
        t0 = bbox[1] - bbox[3] / 2
        t1 = bbox[1] + bbox[3] / 2
        y0 = bbox[2] - bbox[4] / 2
        y1 = bbox[2] + bbox[4] / 2
        range_logic = range_logic | ((time >= t0) & (time <= t1) & (flux >= y0) & (flux <= y1))

    if fill is None:
        return lc_object[range_logic]

    notransiting_lc = lc_object[~range_logic]
    notransiting_lc['flux'] = fill
    filled_lc = lc_object[range_logic].append(notransiting_lc)
    filled_lc.sort('time')
    return filled_lc


def main():
    """
    Command line demo: download a Kepler light curve, run the transit
    detection on it and show the detrended light curve with the boxes found.
    """
    parser = argparse.ArgumentParser(description='demo for lc detection')
    parser.add_argument('-lc', type=str, default='11446443', help='light curve number of KIC, used as src')
    parser.add_argument('-m', '--model_path', type=str, help='model path, will download if empty')
    parser.add_argument('-b', '--batch', type=int, help='batchsize used to inference', default=3)
    parser.add_argument('--backend', type=str, help='backend of model, use pytorch/megengine', default='pytorch')
    parser.add_argument('-d', '--device', type=str, help='runtime device of backend', default=None)
    parser.add_argument('--nms_iou_threshold', type=float, help='nms iou threshold', default=None)
    parser.add_argument('--confidence_threshold', type=float, help='confidence threshold', default=None)
    args = parser.parse_args()

    search_result = lk.search_lightcurve('KIC {}'.format(args.lc), author='Kepler')
    lc = search_result.download_all().stitch()
    # Keep only the samples with time values below 135 for the demo.
    lc = lc[lc.time.value < 135]

    # The keyword is ``is_flat``; ``is_flatten`` is not accepted by DeepTransit.
    dt = DeepTransit(lc, is_flat=False, flatten_kwargs={'window_length': 0.5, 'sigma_upper': 3})
    # Detrended copy used only for plotting.
    flat_lc = detrend_light_curve(lc, window_length=0.5)
    bboxes = dt.transit_detection(args.model_path,
                                  batch_size=args.batch,
                                  nms_iou_threshold=args.nms_iou_threshold,
                                  confidence_threshold=args.confidence_threshold,
                                  device_str=args.device,
                                  backend=args.backend)

    fig, ax = plt.subplots()
    ax = plot_lc_with_bboxes(flat_lc, bboxes, ax=ax, lw=1)
    plt.show()


if __name__ == '__main__':
    main()