Disruption flat-top labelling with multiple segments#

I realised that a single flat-top segment can’t capture the true flat-top. In many cases the flat-top is not a perfect plateau; it can contain spikes. So I switched to a divide-and-conquer strategy: instead of forcing one flat-top, I allow multiple flat-top segments to be detected independently, using the same standard-deviation check plus additional masking rules (e.g., above a minimum Ip, before the disruption).

Combining these segments was straightforward. Please read the comments in the combine_segments method.

Download the data#

Code to download just the summary from CSD3.

# Copy only the `summary` group of each level-2 Zarr store, for shot numbers
# 30001 through 30472, into its own local folder under ~/Downloads/level2_summary/.
# rsync flags: -a archive mode, -v verbose, -z compress during transfer.
# NOTE(review): `CSD3` is presumably an SSH host alias configured locally - confirm.
for i in {30001..30472}; do                                         
  rsync -avz --progress CSD3:/rds/project/rds-mOlK9qn0PlQ/fairmast/upload-tmp/level2/${i}.zarr/summary/ ~/Downloads/level2_summary/${i}/
done
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter, find_peaks
from scipy.interpolate import interp1d
import pandas as pd
import matplotlib.pyplot as plt
# Hand-picked shots that must always be part of the evaluation set.
shot_ids_i_like = [30108, 30121, 30178, 30035, 30183, 30086, 30112, 30109, 30209]

# Ten reproducible random shot numbers from [30000, 30200) (fixed seed), merged
# with the hand-picked ones; the set union drops duplicates before sorting.
np.random.seed(0)
shot_ids = sorted(
    set(np.random.randint(30000, 30200, 10).tolist()) | set(shot_ids_i_like)
)

shot_ids
[30009,
 30021,
 30035,
 30036,
 30047,
 30067,
 30086,
 30103,
 30108,
 30109,
 30112,
 30117,
 30121,
 30172,
 30178,
 30183,
 30192,
 30195,
 30209]
def detect_change_points(time, intensity, window_size=50, threshold=2.0):
    """
    Find the times at which the moving average of a signal jumps sharply.

    The signal is divided by 1000 before smoothing. Per the original note:
    level 1 plasma current is in kA while level 2 is in Amps, so the division
    brings level 2 data to kA; otherwise the sample-to-sample differences
    would be too large for the threshold.

    Args:
    - time (np.array): Time array, indexed in step with ``intensity``.
    - intensity (np.array): Signal to analyse.
    - window_size (int): Number of samples in the moving-average window.
    - threshold (float): Minimum absolute difference between consecutive
      moving-average values that counts as a change.

    Returns:
    - Array of times (a slice of ``time``) at which a change was detected,
      or None when nothing exceeds the threshold.
    """
    # Box-car kernel; 'valid' keeps only the fully overlapping windows.
    kernel = np.full(window_size, 1.0 / window_size)
    smoothed = np.convolve(intensity / 1000, kernel, mode='valid')

    step_sizes = np.abs(np.diff(smoothed))

    # Shift by the window length to map positions in the shortened,
    # smoothed series back onto indices of the original arrays.
    jump_idx = np.flatnonzero(step_sizes > threshold) + window_size
    change_points = time[jump_idx]

    return change_points if len(change_points) else None
class DisruptionDetector:
    """Label disruption time, flat-top and ramp-up on a plasma-current trace.

    ``run`` is the entry point. It estimates the disruption time ``td`` with the
    module-level ``detect_change_points``, collects every flat-top segment
    before ``td`` with a rolling standard-deviation check, merges the segments
    into one flat-top interval and finally estimates where the ramp-up ends.

    Of the constructor arguments, only ``disruption_window_size``,
    ``disruption_prominence``, ``flat_top_window_size``, ``flat_top_tolerance``
    and ``plot`` are read by the methods of this class; the remaining ones are
    stored unchanged.
    """

    def __init__(
        self,
        ip_threshold=60.0,
        disruption_window=0.05,
        disruption_window_size=20,
        disruption_poly_order=2,
        disruption_prominence=10.0,
        flat_top_window_size=51,
        flat_top_tolerance=0.01,
        flat_top_interp_kind="linear",
        plot=False
    ):
        # disruption detection
        self.ip_threshold = ip_threshold
        self.disruption_window = disruption_window
        # moving-average window (samples) handed to detect_change_points
        self.disruption_window_size = disruption_window_size
        self.disruption_poly_order = disruption_poly_order
        # threshold handed to detect_change_points
        self.disruption_prominence = disruption_prominence

        # flat top
        # rolling window (samples) for the std check and the ramp-up smoothing
        self.flat_top_window_size = flat_top_window_size
        # a sample is "flat" when the rolling std is below this value
        self.flat_top_tolerance = flat_top_tolerance
        self.flat_top_interp_kind = flat_top_interp_kind

        # when True, run() also draws the labelled trace
        self.plot = plot

    def _detect_flattop(self, ip, time, td=None):
        """
        Return (start, end) times of the longest flat stretch of ``ip``.

        A sample is flat when the centred rolling standard deviation over
        ``flat_top_window_size`` samples is below ``flat_top_tolerance``.
        Windows containing NaN never qualify, which is how callers exclude
        regions.

        ``td`` restricts the search to ``time < td`` because the region after
        the current crash is also flat by this measure. ``None`` or NaN means
        "no disruption known" and disables the cut. (Previously a NaN ``td``
        removed every sample, so such shots never got a flat-top.)

        Returns ``(np.nan, np.nan)`` when nothing is flat or on any error.
        """
        try:
            if td is None:
                print("td is none. can't mask post disruption zone.")
            elif not np.isnan(td):
                before_td = time < td
                ip = ip[before_td]
                time = time[before_td]

            # Rolling std; the default min_periods makes incomplete or
            # NaN-containing windows evaluate to NaN.
            rolling_std = (
                pd.Series(ip)
                .rolling(self.flat_top_window_size, center=True)
                .std()
            )

            # NaN compares False, so masked windows are never selected.
            flat_idxs = np.where(rolling_std < self.flat_top_tolerance)[0]
            if len(flat_idxs) == 0:
                return np.nan, np.nan

            # Break the flat indices into runs of consecutive samples.
            breaks = np.where(np.diff(flat_idxs) > 1)[0]
            runs = np.split(flat_idxs, breaks + 1)

            longest = max(runs, key=len)
            return time[longest[0]], time[longest[-1]]

        except Exception as e:
            # Best effort: report the problem and signal "no flat-top".
            print(f"Flat-top detection failed: {e}")
            return np.nan, np.nan

    def detect_all_flattops(self, ip, time, td, min_duration=0.002, max_iters=10):
        """
        Collect flat-top segments one at a time, longest first.

        Samples below 2e4 and, when ``td`` is a real number, samples at or
        after ``td`` are set to NaN. Each iteration finds the longest remaining
        flat stretch, records it and masks it out. The search stops when no
        stretch is left, when a stretch is shorter than ``min_duration`` (same
        units as ``time``), or after ``max_iters`` iterations.

        ``td`` may be ``None`` or NaN, meaning no disruption was detected; the
        whole trace is then searched.

        Returns a list of ``(start, end)`` time tuples in order of discovery.
        """
        # Float dtype is required to store NaN (integer input used to raise).
        ip = np.asarray(ip, dtype=float)
        time = np.asarray(time)

        ip_masked = ip.copy()
        ip_masked[ip < 2e4] = np.nan  # already decided on 20k threshold
        if td is not None and not np.isnan(td):
            ip_masked[time >= td] = np.nan  # td cut
        else:
            # Normalise to NaN so _detect_flattop skips its own cut silently.
            td = np.nan

        segments = []
        for _ in range(max_iters):
            ft_start, ft_end = self._detect_flattop(ip_masked, time, td)
            if np.isnan(ft_start) or np.isnan(ft_end):
                break

            duration = ft_end - ft_start
            if duration < min_duration:
                # Later finds are shorter still and may just be a short flat
                # line on the ramp-up, so stop here.
                print(f"Segment too short: {duration:.4f}s — stopping.")
                break

            segments.append((ft_start, ft_end))
            # Mask the used part so the next pass finds a different stretch.
            ip_masked[(time >= ft_start) & (time <= ft_end)] = np.nan

        return segments

    def _detect_rampup_simple(self, flat_top_start):
        """Treat everything from t=0 up to the flat-top start as ramp-up."""
        return 0.0, flat_top_start

    def _detect_rampup(self, ip, time, flat_top_start):
        """
        Estimate the ramp-up interval as ``(0.0, t_ramp_end)``.

        Only samples with ``0.05 < time <= flat_top_start`` are used. The trace
        is smoothed with a centred rolling mean; the ramp-up ends at the last
        sample whose gradient exceeds 70% of the maximum gradient.

        Returns ``(0.0, np.nan)`` when ``flat_top_start`` is NaN, and
        ``(0.0, flat_top_start)`` when no sample qualifies or on any error.
        """
        try:
            if np.isnan(flat_top_start):
                return 0.0, np.nan  # nothing to anchor the ramp-up to

            ip = np.asarray(ip)
            time = np.asarray(time)

            # Data before the flat-top, skipping the first 0.05 of the time
            # axis to enforce a minimum ramp-up duration.
            window = (time <= flat_top_start) & (time > 0.05)
            ip = ip[window]
            time = time[window]

            ip_smooth = (
                pd.Series(ip)
                .rolling(window=self.flat_top_window_size, center=True)
                .mean()
            )
            # Edge windows are NaN after smoothing; treat them as zero slope.
            grad = np.nan_to_num(np.gradient(ip_smooth.to_numpy()))

            grad_thresh = 0.7 * np.nanmax(grad)
            high_grad_idxs = np.where(grad > grad_thresh)[0]
            if len(high_grad_idxs) == 0:
                return 0.0, flat_top_start
            return 0.0, time[high_grad_idxs[-1]]

        except Exception as e:
            # Best effort: fall back to "ramp-up lasts until the flat-top".
            print(f"Ramp-up detection failed: {e}")
            return 0.0, flat_top_start

    def combine_segments(self, segments):
        """
        Merge flat-top segments into one ``(start, end)`` interval.

        Deliberately simple: the earliest start and the latest end are used,
        whatever the gaps between segments. The author's observation is that
        spikes inside the flat-top split it into several segments, all of
        which still belong to the real flat-top.

        Returns ``(np.nan, np.nan)`` for an empty list.
        """
        if not segments:
            return np.nan, np.nan
        starts, ends = zip(*segments)
        return min(starts), max(ends)

    def _plot_with_segments(self, shot, ip, time, segments, td, ru_start, ru_end):
        """Draw the trace with td, the ramp-up bounds and each flat-top segment.

        ``run`` already called this method when ``plot=True``, but the class did
        not define it, so that path raised AttributeError.
        """
        # Local import: only needed on the plotting path.
        import matplotlib.pyplot as plt

        plt.figure(figsize=(10, 4))
        plt.plot(time, ip)

        markers = (
            ("rampup_start", ru_start, "red"),
            ("rampup_end", ru_end, "blue"),
            ("td", td, "purple"),
        )
        for label, value, color in markers:
            if value is None or np.isnan(value):
                continue
            plt.axvline(x=value, color=color, linestyle="--", label=f"{label}={value:.3f}")

        for i, (start, end) in enumerate(segments):
            # Label only the first span so the legend has a single entry.
            plt.axvspan(start, end, color="gray", alpha=0.3,
                        label="flat-top segment" if i == 0 else None)

        plt.xlabel("Time (s)")
        plt.ylabel("Plasma Current")
        plt.title(f"Shot {shot}")
        plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
        plt.tight_layout()

    def run(self, shot, ip, time):
        """
        Label one shot.

        Args:
            shot: Identifier copied into the result.
            ip: Plasma-current samples.
            time: Time samples, same length as ``ip``.

        Returns:
            dict with keys ``shot``, ``td`` (NaN when no change point is
            found), ``flattop_start``, ``flattop_end``, ``rampup_start``,
            ``rampup_end`` and ``flattop_segments`` (every individual segment).
        """
        ip = np.asarray(ip)
        time = np.asarray(time)

        # remove early garbage
        valid = time >= 0
        ip = ip[valid]
        time = time[valid]

        change_points = detect_change_points(
            time,
            ip,
            window_size=self.disruption_window_size,
            threshold=self.disruption_prominence,
        )
        # The last change point is taken as the disruption time.
        td = change_points[-1] if change_points is not None and len(change_points) > 0 else np.nan

        flattop_segments = self.detect_all_flattops(ip, time, td)
        # Combine all segments into the overall flat-top start and end.
        ft_start, ft_end = self.combine_segments(flattop_segments)

        ru_start, ru_end = self._detect_rampup(ip, time, ft_start)

        if self.plot:
            self._plot_with_segments(shot, ip, time, flattop_segments, td, ru_start, ru_end)

        return {
            "shot": shot,
            "td": td,
            "flattop_start": ft_start,
            "flattop_end": ft_end,
            "rampup_start": ru_start,
            "rampup_end": ru_end,
            "flattop_segments": flattop_segments,  # all segments so we can plot and see what is going on
        }
# Detector settings: a larger window and a larger threshold for the
# standard-deviation based flatness check.
params = dict(
    ip_threshold=60.0,
    disruption_window=0.05,
    disruption_window_size=20,
    disruption_poly_order=2,
    disruption_prominence=10,
    flat_top_window_size=100,
    flat_top_tolerance=10000,
    flat_top_interp_kind="linear",
)

# Label every selected shot and gather the per-shot result dicts.
results = []
for shot in shot_ids:
    ds = xr.open_zarr(f"./level2_copy/{shot}", consolidated=False)
    ip, time = ds["ip"].values, ds["time"].values

    detector = DisruptionDetector(**params)
    result = detector.run(shot, ip, time)
    results.append(result)

# One row per labelled shot.
df = pd.DataFrame(results)
Segment too short: 0.0005s — stopping.
df.head()  # notebook display: preview the first rows of the labelled results
shot td flattop_start flattop_end rampup_start rampup_end flattop_segments
0 30009 0.18440 0.13240 0.15915 0.0 0.11990 [(0.1323997106552126, 0.15914971065521266)]
1 30021 0.48880 0.12755 0.44855 0.0 0.11505 [(0.12754970932006857, 0.4485497093200689)]
2 30035 0.21515 0.13165 0.19465 0.0 0.11490 [(0.1316497573852541, 0.14564975738525412), (0...
3 30036 0.35580 0.13730 0.33380 0.0 0.10680 [(0.21904973506927516, 0.3337997350692753), (0...
4 30047 0.62805 0.09480 0.60505 0.0 0.08205 [(0.09479966259002703, 0.6050496625900275)]
def plot_single_shot(shot_id, ip, time, labels_dict):
    """
    Plot one shot's current trace with td and every individual flat-top segment.

    Args:
    - shot_id: Shot identifier, used in the figure title.
    - ip (np.array): Plasma current samples.
    - time (np.array): Time samples, same length as ``ip``.
    - labels_dict (dict): Detector result; the keys "td" and
      "flattop_segments" (list of (start, end) pairs) are read.

    A new figure is created on each call. It is not shown or saved here.
    """
    palette = ['red', 'blue', 'green', 'orange', 'purple']
    marker_keys = ['td']

    plt.figure(figsize=(10, 4))
    plt.plot(time, ip)

    # One vertical marker per label; a missing/NaN value is drawn dotted at
    # the first time sample so it still appears in the legend.
    for idx, key in enumerate(marker_keys):
        value = labels_dict.get(key, np.nan)
        if np.isnan(value):
            plt.axvline(x=time[0], color=palette[idx], linestyle=':', label=f"{key}=nan")
        else:
            plt.axvline(x=value, color=palette[idx], linestyle='--', label=f"{key}={value:.3f}")

    flat_segments = labels_dict.get("flattop_segments", [])

    # Edges of each segment plus an "FT<n>" tag at half the trace maximum.
    for idx, (seg_start, seg_end) in enumerate(flat_segments):
        for edge in (seg_start, seg_end):
            plt.axvline(edge, color="gray", linestyle="--", alpha=0.5)
        midpoint = (seg_start + seg_end)/2
        plt.text(midpoint, max(ip)*0.5, f"FT{idx}", ha="center", va="bottom", fontsize=8, color="gray")

    # Horizontal bar per segment at half the median current inside it.
    for seg_start, seg_end in flat_segments:
        inside = (time >= seg_start) & (time <= seg_end)
        if np.any(inside):
            bar_level = np.median(ip[inside])/2
            plt.hlines(bar_level, seg_start, seg_end, color="gray", linewidth=2, alpha=0.8)

    plt.xlabel("Time (s)")
    plt.ylabel("Plasma Current")
    plt.title(f"Shot {shot_id}")
    plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    plt.tight_layout()
# Re-open every labelled shot and draw it together with its detector result.
for result in results:
    shot_id = result["shot"]
    ds = xr.open_zarr(f"./level2_copy/{shot_id}", consolidated=False)
    ip, time = ds["ip"].values, ds["time"].values

    plot_single_shot(shot_id, ip, time, result)
../../_images/edfc53ded4623d89d3b2e98843e6f42bb298b41cb081c285dbc148e27f485fec.png ../../_images/c2274d7682353575e41c0ecb34eda20c6abbfb7bfc7333ff89d09b03d1b19906.png ../../_images/6d394838462be5b9a5b755381f0ffbe23b5003fe239c5e2b7ce78239ed72e419.png ../../_images/7dc628aabf73030cf98effd52a1ddae8ef1c21e986c333f7d6dccc941b2c4253.png ../../_images/00d83e13e9add99c0681b8a92cb784ea40aa960b63d75406a38008ed1af9f6b6.png ../../_images/cf20c13ec9e4a27f7c6db90912ecd577b97e8df0194d634e0231f89d6c56a82d.png ../../_images/2c12cb1a88ba58990e6bc984b6d44e5a3eee8ff26b650b5f7ee485593a025349.png ../../_images/8addcbd67001d86d856c2a7b0696457e2660e816a3956ca66d92e2e861dd5037.png ../../_images/c60b9ded18ab239c33f411eb90c8f2372c2c4c396708a65ec08a109fc9a5853e.png ../../_images/09afe2dd73d3ab2c5d275e5e8ed60681c52c520a6953604acf7046679f361d75.png ../../_images/a934e150e653279d22ea09af0eb6ddcf924b4739fd669eba62aa85184f30da28.png ../../_images/da20db901c1ea40250cbd34f0d31fbfa1d877536c7cc305ec56adf873846d61b.png ../../_images/24e52781933d2eb4829825dd09dd713dc4c716f7a1b620d8fbfdd9534c07a6b7.png ../../_images/5fc250c61f0efc1a7cb84ca2ae8b5409c6505283b6e059a7bd3437724c7f2c92.png ../../_images/90fc7e82ca3258776ce04761d37c6f99e806086a45a61300fab1922613e7efbe.png ../../_images/ac73135e28a92dfa3176f0e1df6efe55718a7a24b40ad47cb5d10af9b830a23c.png ../../_images/91b2c0b5eaa25317ee6eb36cbfec64d6403278f8bd03b77152761368fca15fa2.png ../../_images/d91a64fd1cb022d47af88ce82d080dc76fcaa9b5d03ea0f73b2e61cd752cc53e.png ../../_images/da51324741f26cca4115137d6e548e1344f257f31056771ae9c7df94f8b1654c.png

Plot the important bits#

def plot_single_shot(shot_id, ip, time, labels_dict):
    """
    Plot one shot's current trace with all phase markers and the combined flat-top.

    Args:
    - shot_id: Shot identifier, used in the figure title.
    - ip (np.array): Plasma current samples.
    - time (np.array): Time samples, same length as ``ip``.
    - labels_dict (dict): Detector result; the keys "rampup_start",
      "rampup_end", "flattop_start", "flattop_end" and "td" are read.

    A new figure is created on each call. It is not shown or saved here.
    """
    palette = ['red', 'blue', 'green', 'orange', 'purple']
    marker_keys = ['rampup_start', 'rampup_end', 'flattop_start', 'flattop_end', 'td']

    plt.figure(figsize=(10, 4))
    plt.plot(time, ip)

    # One vertical marker per label; a missing/NaN value is drawn dotted at
    # the first time sample so it still appears in the legend.
    for idx, key in enumerate(marker_keys):
        value = labels_dict.get(key, np.nan)
        if np.isnan(value):
            plt.axvline(x=time[0], color=palette[idx], linestyle=':', label=f"{key}=nan")
        else:
            plt.axvline(x=value, color=palette[idx], linestyle='--', label=f"{key}={value:.3f}")

    # Horizontal bar across the combined flat-top at half the median current
    # inside it, with an "FT" tag at half the trace maximum.
    ft_start = labels_dict.get("flattop_start", np.nan)
    ft_end = labels_dict.get("flattop_end", np.nan)
    has_flattop = not (np.isnan(ft_start) or np.isnan(ft_end))
    if has_flattop:
        inside = (time >= ft_start) & (time <= ft_end)
        if np.any(inside):
            ft_level = np.median(ip[inside])*0.5
            plt.hlines(ft_level, ft_start, ft_end, color="gray", linewidth=2, alpha=0.8)
            plt.text((ft_start + ft_end)/2, max(ip)*0.5, "FT", ha="center", va="bottom", fontsize=8, color="gray")

    plt.xlabel("Time (s)")
    plt.ylabel("Plasma Current")
    plt.title(f"Shot {shot_id}")
    plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    plt.tight_layout()
# Re-open every labelled shot and draw it together with its detector result.
for result in results:
    shot_id = result["shot"]
    ds = xr.open_zarr(f"./level2_copy/{shot_id}", consolidated=False)
    ip, time = ds["ip"].values, ds["time"].values

    plot_single_shot(shot_id, ip, time, result)
../../_images/27cf3aa8fc624a526ccd425afe259d1f8b75cf88a515181c9e56d3e0e3a8f604.png ../../_images/a7d9ab1ce48fef6ef5ca9216b167df5ae1d6e65395020c8b10a58af2d29fe57c.png ../../_images/3a2eead5319d2898e9c1cc26ec1aabbd5bb02a0caeee36c1c26b2eab1db39fc4.png ../../_images/5ed87c7d21a00e4b2bef7a1829abe68a8f47e648a9e66c9247a95129210a4a55.png ../../_images/3f2fb74c8534f0bd870f694f2764f86280c5396eaf27dae7911a9294747d0e0a.png ../../_images/21f210d0519891b18596f702cd44a655898e716720a4df13bb94d17ea50201d4.png ../../_images/08283d412b9d31c8a8904c1b1390b50e97984c18576f7a290071112813f78f1f.png ../../_images/1a7ce6bf66d70641191809bb4ca4a2314ac0c1bf1f16a5ffbf696bee26027466.png ../../_images/0bbe8e2d027c76cea3c574545b3459329911d5a27563a958c2597b6104f7f9d7.png ../../_images/cdad31308eacec6e686a320b703a73222063e3f1d7a3a4fad473d1fb501ae205.png ../../_images/853c2f850fb3151f60885a56d5c0cd488dfd1b50f481eecc3bd2c1d0c46df86e.png ../../_images/9e65766ee9578dc3360eb4a009888a1c1589abcbdc86dd80bcc0dee5dbc5a58c.png ../../_images/9026248f0625c5cb6ce9f5f46254ecd6fc4343e80bd4640c4763707346089081.png ../../_images/426d927073db7228b53c085c8267ed804ce3c27669ff4dc096443b683e13098f.png ../../_images/5bce10badb44ed081816433662896b7f87afa74a3470f412a7eb370043052b06.png ../../_images/be3b85c296a0a6da482121dd96ff2ad88d99f4cd95183a8f7778d0b2c7450fe0.png ../../_images/f1f08e73d8df483af82bbcf74065e4871a8dbd06fc12eb2b9f2cdd4adab1a079.png ../../_images/8dfdc1ca08cdc6a35e8b1ab53823a93ffa6b19b5f0ea86610c9b8bae86bdd568.png ../../_images/263f2624b751a41bdcfae01384300f63872bf69191b111130a6013d9298dd8b5.png