Percentage-based thresholding with change-point detection

# Denoising
import zarr
import zarr.storage
import fsspec
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from scipy.signal import stft
from scipy.signal import find_peaks
from collections import defaultdict
from scipy.ndimage import median_filter, gaussian_filter
from skimage import measure
import pandas as pd
from scipy.ndimage import gaussian_filter1d, label
import ruptures as rpt
# Shot IDs to analyse (extend as needed).
shot_ids = [23447, 30005, 30021, 30421]

# S3 endpoint, accessed anonymously through a local "simplecache" layer.
endpoint = "https://s3.echo.stfc.ac.uk"
fs = fsspec.filesystem(
    protocol='simplecache',
    target_protocol="s3",
    target_options=dict(anon=True, endpoint_url=endpoint)
)
store_list = []
zgroup_list = []
loaded_shot_ids = []

# Open every shot; only shots that open successfully are recorded, so that
# store_list, zgroup_list and shot_ids always stay index-aligned.
for shot_id in shot_ids:
    url = f"s3://mast/level2/shots/{shot_id}.zarr"

    store = zarr.storage.FSStore(fs=fs, url=url)
    try:
        zgroup = zarr.open(store, mode='r')
    except Exception as e:
        # Best effort: report the failure and carry on with the other shots.
        print(f"Failed to load shot ID {shot_id}: {e}")
        continue

    store_list.append(store)
    zgroup_list.append(zgroup)
    loaded_shot_ids.append(shot_id)
    print(f"Loaded shot ID {shot_id}")

# Later cells index shot_ids in lockstep with the loaded data, so drop any
# shot that failed to open.
shot_ids = loaded_shot_ids
Loaded shot ID 23447
Loaded shot ID 30005
Loaded shot ID 30021
Loaded shot ID 30421
# Open the "magnetics" group of every loaded shot as an xarray Dataset.
mirnov = [xr.open_zarr(shot_store, group="magnetics") for shot_store in store_list]

# For each shot, pick channel index 1 of the OMV poloidal-field probe voltage.
ds_list = [
    magnetics['b_field_pol_probe_omv_voltage'].isel(b_field_pol_probe_omv_channel=1)
    for magnetics in mirnov
]

STFT (Short-Time Fourier Transform)

def plot_stft_spectrogram(
    ds,
    shot_id=None,
    nperseg=2000,
    nfft=2000,
    tmin=0.1,
    tmax=0.46,
    fmax_kHz=50,
    cmap='jet',
):
    """
    Plot STFT spectrogram for a given xarray DataArray `ds`.

    Parameters:
    - ds: xarray.DataArray with a 'time_mirnov' coordinate.
    - shot_id: Optional shot ID for labeling.
    - nperseg: Number of points per STFT segment.
    - nfft: Number of FFT points.
    - tmin, tmax: Time range (seconds). Accepted for interface compatibility
      but currently not applied; the full time axis is shown.
    - fmax_kHz: Max frequency to display (kHz).
    - cmap: Colormap name.

    Returns:
    - None. A new matplotlib figure is created as a side effect.
    """
    # Sampling rate from the spacing of the first two time stamps. It is
    # passed on as a float: truncating with int() turns a value such as
    # 499999.99... into 499999 and skews every frequency and time bin.
    sample_rate = 1.0 / float(ds.time_mirnov[1] - ds.time_mirnov[0])
    f, t, Zxx = stft(ds.values, fs=sample_rate, nperseg=nperseg, nfft=nfft)

    fig, ax = plt.subplots(figsize=(15, 5))
    cax = ax.pcolormesh(
        t, f / 1000, np.abs(Zxx),
        shading='nearest',
        cmap=plt.get_cmap(cmap, 15),
        norm=LogNorm(vmin=1e-5)
    )
    ax.set_ylim(0, fmax_kHz)
    ax.set_ylabel('Frequency [kHz]')
    ax.set_xlabel('Time [sec]')
    title = "STFT Spectrogram"
    if shot_id is not None:
        title += f" - Shot {shot_id}"
    ax.set_title(title)
    plt.colorbar(cax, ax=ax, label='Amplitude')
    plt.tight_layout()
# One spectrogram per shot. A plain loop is used because the function is
# called purely for its plotting side effect; a list comprehension would only
# build (and echo) a throwaway list of None values.
for ds, shot_id in zip(ds_list, shot_ids):
    plot_stft_spectrogram(ds, shot_id)
../../_images/5aa9f2f2dbeb3ea7a55b628f64814e7894d2d06a9fc6088a3f079383d969826e.png ../../_images/26fc2e5ca74553a500c823e7f4afab7701250dfc7bc4c8948ee67fca0255aa63.png ../../_images/a84df2cd8ecff086ba7777b0664e428982f7f985ba36f76f1be3ffff5dae8420.png ../../_images/22929c875837679fa57c3c94e81ff238709d9766f8d510e96898bfe096ff1418.png

Thresholding based on percentage#

def plot_amplitude_masking(ds, shot_id=None, nperseg=2000, nfft=2000,
                           sigma=1.0, apply_gaussian=True, apply_mask=True,
                           mask_percentile=60, use_percentage=True,
                           tmin=0.1, tmax=0.46, fmax_kHz=50, cmap='jet'):
    """
    Compute the STFT of `ds`, optionally blur and threshold its magnitude,
    plot the result and return the arrays.

    Parameters:
    - ds: xarray.DataArray with 'time_mirnov'
    - shot_id: Optional shot ID
    - nperseg: Number of points per STFT segment
    - nfft: Number of FFT points
    - sigma: Standard deviation of the Gaussian blur (in bins)
    - apply_gaussian: Whether to apply Gaussian blur
    - apply_mask: Whether to apply masking
    - mask_percentile: If use_percentage=True, keep top X% of points.
                       Else, mask values below the Xth percentile.
    - use_percentage: Use percentage thresholding instead of percentile
    - tmin, tmax: Time range (seconds). Accepted for interface compatibility
      but currently not applied; the full time axis is shown.
    - fmax_kHz: Max frequency to display (kHz)
    - cmap: Colormap name

    Returns:
    - (t, f, segmented_stft, Zxx): STFT time axis, STFT frequency axis (same
      unit as the sampling rate, i.e. 1 / time unit of 'time_mirnov'), the
      blurred/masked magnitude with masked bins set to NaN, and the raw
      complex STFT.
    - None if every bin ended up masked (nothing is plotted in that case).
    """

    # Sampling rate from the spacing of the first two time stamps. It is
    # passed on as a float: truncating with int() turns a value such as
    # 499999.99... into 499999 and skews every frequency and time bin.
    sample_rate = 1.0 / float(ds.time_mirnov[1] - ds.time_mirnov[0])
    f, t, Zxx = stft(ds.values, fs=sample_rate, nperseg=nperseg, nfft=nfft)
    magnitude = np.abs(Zxx)  # magnitude of the complex STFT

    # Optional Gaussian blur (smooths the spectrogram before thresholding)
    if apply_gaussian:
        magnitude = gaussian_filter(magnitude, sigma=sigma)

    # Defensive: make sure no negative amplitudes survive the filtering step
    magnitude = np.clip(magnitude, 0, None)

    if apply_mask:
        if use_percentage:
            # Keep the top `mask_percentile` percent of the finite values.
            valid = magnitude[np.isfinite(magnitude)]
            if valid.size > 0:
                sorted_vals = np.sort(valid)
                cut_index = int((1 - mask_percentile / 100) * sorted_vals.size)
                # Clamp so that 0% does not index past the end and values
                # above 100% do not wrap around to a negative index.
                cut_index = min(max(cut_index, 0), sorted_vals.size - 1)
                cutoff_value = sorted_vals[cut_index]
                magnitude = np.where(magnitude >= cutoff_value, magnitude, np.nan)
        else:
            # Standard percentile threshold: mask everything below it
            threshold = np.percentile(magnitude, mask_percentile)
            magnitude = np.where(magnitude >= threshold, magnitude, np.nan)

    # Keep a copy of the segmented spectrogram to hand back to the caller
    segmented_stft = magnitude.copy()

    # Skip if everything got masked (to avoid plotting empty images)
    if not np.any(np.isfinite(magnitude)):
        print(f"Shot {shot_id} — all values masked. Skipping plot.")
        return

    fig, ax = plt.subplots(figsize=(15, 5))
    cax = ax.pcolormesh(
        t, f / 1000, magnitude,
        shading='nearest',
        cmap=plt.get_cmap(cmap, 15),
        norm=LogNorm(vmin=1e-5)
    )
    ax.set_ylim(0, fmax_kHz)
    ax.set_ylabel('Frequency [kHz]')
    ax.set_xlabel('Time [sec]')
    # `is not None` (rather than truthiness) so that a shot ID of 0 is shown
    title = f"Filtered STFT - Shot {shot_id}" if shot_id is not None else "Filtered STFT"
    ax.set_title(title)
    plt.colorbar(cax, ax=ax, label='Amplitude')
    plt.tight_layout()

    return t, f, segmented_stft, Zxx
# Threshold every shot (keep the top 1% of bins after a sigma=0.9 blur) ...
results = [
    plot_amplitude_masking(signal_da, sid, mask_percentile=1, sigma=0.9)
    for signal_da, sid in zip(ds_list, shot_ids)
]
# ... and transpose the per-shot tuples into one tuple per returned quantity.
t_list, f_list, seg_list, Zxx_list = zip(*results)
../../_images/6d70524da83a9b0b122a2ba9f2c3d1ffaa419b04b2442a0fc51fea68bd7789b6.png ../../_images/cabda6b72f2efdf9d106712410459d7d4008b904205b6e48f6605b8db67753e2.png ../../_images/647c237fe79fd2ebc71c6413f696f2080cc8edc3af725ade19ba75e2d8eafccb.png ../../_images/8a92d9d3e6f436994e6c9cc4f7bbaf0ae496c2f1c0ab581f715b2bb642873708.png

Contour detection#

def plot_spectrogram_with_contours(Zxx, f, t, contours, shot_id=None, vmin=1e-5, fmax_kHz=50):
    """
    Plot a spectrogram with contour outlines drawn on top.

    Parameters:
    - Zxx: 2D array (freq x time); its absolute value is displayed.
    - f: 1D frequency axis; divided by 1000 for display in kHz.
    - t: 1D time axis.
    - contours: iterable of (N, 2) arrays of (row, column) coordinates in
      array-index space, e.g. as returned by skimage.measure.find_contours.
    - shot_id: Optional shot ID for the title.
    - vmin: Lower limit of the logarithmic colour scale.
    - fmax_kHz: Max frequency to display (kHz).
    """
    f_kHz = np.asarray(f) / 1000
    t = np.asarray(t)

    # Index grids used to convert fractional contour coordinates to axis values
    freq_bins = np.arange(len(f_kHz))
    time_bins = np.arange(len(t))

    fig, ax = plt.subplots(figsize=(12, 5))
    cax = ax.pcolormesh(t, f_kHz, np.abs(Zxx), shading='nearest',
                        norm=LogNorm(vmin=vmin), cmap='jet')

    for contour in contours:
        # Contour vertices are fractional array indices. Interpolating them
        # onto the axes (instead of truncating to whole bins) keeps outlines
        # on the true region boundary; np.interp also clamps coordinates that
        # fall outside the grid to the first/last axis value.
        ax.plot(np.interp(contour[:, 1], time_bins, t),
                np.interp(contour[:, 0], freq_bins, f_kHz),
                color='black', lw=0.5)

    ax.set_ylim(0, fmax_kHz)
    ax.set_xlabel('Time [sec]')
    ax.set_ylabel('Frequency [kHz]')
    title = "STFT with Contour Overlay"
    if shot_id is not None:
        title += f" - Shot {shot_id}"
    ax.set_title(title)
    plt.colorbar(cax, ax=ax, label="Amplitude")
    ax.grid(True)
    plt.tight_layout()
# Outline every surviving region of each masked spectrogram.
for masked_stft, freqs, times, sid in zip(seg_list, f_list, t_list, shot_ids):
    # Masked-out bins are NaN, so the finite bins form the foreground of a 0/1 image
    binary_mask = np.isfinite(masked_stft).astype(float)
    # Level 0.5 sits halfway between background (0) and foreground (1)
    contours = measure.find_contours(binary_mask, level=0.5)
    plot_spectrogram_with_contours(masked_stft, freqs, times, contours, shot_id=sid)
../../_images/da127f9ac4c5878d8e0f7ac63187dd78dbc8f24630d18da360c2025d3a048805.png ../../_images/fcd5648ba0b9cacdf551d8bef57dd3fed8f9a6b5df7668423b8f8f13112204ed.png ../../_images/a7ec0e64c9b25044e128198e5b88e152f9f0f23ed7a17fe661e7e5c245329a09.png ../../_images/b4970c13e682be2ae345b0251c1dcc365495c225f53d746f2b274cadc60ec4aa.png
# Minimum number of vertices a contour needs to be drawn (tunable).
min_contour_length = 15

for masked_stft, freqs, times, sid in zip(seg_list, f_list, t_list, shot_ids):
    binary_mask = np.isfinite(masked_stft).astype(float)
    contours = measure.find_contours(binary_mask, level=0.5)

    # Discard contours with too few vertices
    filtered_contours = []
    for candidate in contours:
        if len(candidate) >= min_contour_length:
            filtered_contours.append(candidate)

    plot_spectrogram_with_contours(masked_stft, freqs, times, filtered_contours, shot_id=sid)
../../_images/39c9cd5e2c67ad7421315c91e5fb13f99b0877fa1b5684ea3d7959c91134c912.png ../../_images/4bc1333918983e6e07e0ea151a318bcbc51d86fdf127acc95e02b54e0bee8c4d.png ../../_images/ec3d51abe20d42564d9b5929d5cf6f44e01f43875a7c6dda8ade5cd0e85458c3.png ../../_images/a05769f9729dbf405d3578ed1f89032c8e7fe5c752fcf07d13d3c77879d12553.png

Compute Avg Freq#

  • Average frequency (amplitude-weighted, per time bin) → indicates which mode is dominant at that time

  • Average amplitude → indicates how strong the mode is

def amplitude_weighted_avg_freq(stft_amp, f):
    """
    Amplitude-weighted mean frequency for every time column.

    Parameters:
    - stft_amp: 2D amplitude array (n_freqs x n_times); NaN entries count as zero weight.
    - f: 1D frequency array of length n_freqs.

    Returns:
    - 1D array of length n_times; NaN wherever a column's total weight is zero.
    """
    weights = np.nan_to_num(stft_amp, nan=0.0)
    freq_column = f[:, None]  # (n_freqs, 1) so it broadcasts across the time axis

    numerator = (freq_column * weights).sum(axis=0)
    total_weight = weights.sum(axis=0)

    # A zero denominator is replaced by NaN so the result is NaN rather than a division error
    denominator = np.where(total_weight == 0, np.nan, total_weight)
    return numerator / denominator
# Amplitude-weighted average frequency per shot, stored in the unit of the
# STFT frequency axis (Hz); the values are only converted to kHz for plotting.
avg_freq_list = []
for i in range(len(ds_list)):
    avg_freq = amplitude_weighted_avg_freq(seg_list[i], f_list[i])
    avg_freq_list.append(avg_freq)
    plt.figure(figsize=(10, 4))
    # Divide by 1000 so the plotted values match the kHz axis label
    plt.plot(t_list[i], avg_freq / 1000, label="Avg Freq", color="tab:blue")
    plt.xlabel("Time [s]")
    plt.ylabel("Avg Frequency [kHz]")
    plt.title(f"Avg Freq vs Time - Shot {shot_ids[i]}")
    plt.grid(True)
    plt.legend()
    plt.show()
../../_images/d6aa1bbf61b22bbf89d580cedddff505d646bf4b91048ed2b640608363059a31.png ../../_images/7a392f1bb4cbd4d9934c2adadf47cd7f823bfe419769349ab28f78a3dafc7312.png ../../_images/f048de9b4536bb5190e53c9a038b14730ef1adead1fc8db285adc21a89a26023.png ../../_images/62d23fdb13d4a940c6f2adc3a83be4d16eba4990b7bc7bd99b2b9d365528c407.png
# Mark, for every shot, the time bin where the average frequency drops fastest.
for i in range(len(ds_list)):
    # Smoothing to reduce noise.
    # NOTE(review): sigma is in samples; 0.01 is far below one sample, so this
    # filter looks like it leaves the trace essentially unchanged - confirm
    # the intended value.
    avg_freq_smooth = gaussian_filter1d(avg_freq_list[i], sigma=0.01)

    # Rate of change between neighbouring time bins
    d_avg = np.gradient(avg_freq_smooth)

    plt.figure(figsize=(10, 4))
    # Values are stored in Hz; divide by 1000 to match the kHz axis label
    plt.plot(t_list[i], avg_freq_list[i] / 1000, label="Avg Freq", alpha=0.5)
    plt.plot(t_list[i], avg_freq_smooth / 1000, label="Smoothed", color="tab:blue")

    # Strongest negative slope (drop in frequency). The trace is NaN wherever
    # a whole time column was masked, and np.argmin would return the first
    # NaN position instead of the steepest drop, so NaNs are ignored here.
    if np.all(np.isnan(d_avg)):
        print(f"Shot {shot_ids[i]}: no finite gradient, change point skipped.")
    else:
        change_idx = np.nanargmin(d_avg)
        change_time = t_list[i][change_idx]
        plt.axvline(change_time, color="red", linestyle="--", label="Change Point")

    plt.xlabel("Time [s]")
    plt.ylabel("Avg Frequency [kHz]")
    plt.title(f"Avg Freq with Change Point - Shot {shot_ids[i]}")
    plt.legend()
    plt.grid(True)
../../_images/a8fbbbcc5a61e2c888d5d57524ce87ed2d4a61865c72321642322d799e994fd0.png ../../_images/f46475c09395cc85320ef34da6d22a7411418062fce3109a2fb4b422a98c5452.png ../../_images/8e68d7aa6ad1598098911026064a956718c8243d43d331efc1b293af20a5b669.png ../../_images/88aeb531b68a8c1793521b1508087721b39aa2e70a41e84efb89eef7dfe7318a.png

Ruptures library#

# Global change-point detection on the LAST shot only. The shot index and the
# smoothed trace are set up explicitly here instead of relying on whatever
# values an earlier loop happened to leave behind.
i = len(ds_list) - 1
avg_freq_smooth = gaussian_filter1d(avg_freq_list[i], sigma=0.01)

# ruptures expects a 2D (n_samples, n_features) signal
signal = avg_freq_smooth.reshape(-1, 1)

# PELT search with an L2 (mean-shift) cost; the penalty controls how many
# change points are returned.
# NOTE(review): the trace is in Hz and may contain NaN; the recorded output
# (a change point at every 5th time bin) suggests the penalty is too small
# for this scale or that NaNs upset the cost - confirm before trusting it.
algo = rpt.Pelt(model="l2").fit(signal)
change_locs = algo.predict(pen=50000)  # adjust penalty to control # of CPs

# Convert breakpoint indices to times; indices at or beyond the end of the
# time axis are dropped.
change_times = [t_list[i][j] for j in change_locs if j < len(t_list[i])]

# Draw on a fresh figure rather than on whichever figure is current
plt.figure(figsize=(10, 4))
plt.plot(t_list[i], avg_freq_smooth)
for ct in change_times:
    plt.axvline(ct, color='red', linestyle='--')
plt.xlabel("Time [s]")
plt.ylabel("Avg Frequency [Hz]")
plt.title(f"Ruptures CPD - Shot {shot_ids[i]}")
Text(0.5, 1.0, 'Ruptures CPD - Shot 30421')
../../_images/439d6a6973546931806eaf5278ba6b817208fccb8c05192fa44ffc9ac910a64a.png
def plot_avg_freq_cpd_bursts(ds_list, t_list, avg_freq_list, shot_ids, amp_thresh_percentile=30,
                              sigma=2, pen=500000, min_gap=0.02, min_burst_len=5):
    """
    Burst-wise CPD on amplitude-weighted avg frequency using ruptures.

    Parameters:
    - ds_list: list of 2D STFT amplitude arrays (freq x time). Despite the
      name these must NOT be the raw 1D signals: the time axis has to match
      the corresponding entry of t_list.
    - t_list: list of 1D time arrays
    - avg_freq_list: list of 1D avg frequency arrays (from STFT), in Hz
    - shot_ids: list of shot IDs
    - amp_thresh_percentile: amplitude threshold for detecting bursts
    - sigma: Gaussian smoothing factor
    - pen: ruptures penalty (model='rbf' is used)
    - min_gap: min time between CPs to be accepted
    - min_burst_len: min burst length in time bins

    Returns:
    - None. One matplotlib figure per shot is created as a side effect.
    """

    for i in range(len(ds_list)):
        amp = np.nan_to_num(ds_list[i], nan=0.0)
        total_amp = np.sum(amp, axis=0)  # total amplitude per time bin
        t = t_list[i]

        # Burst mask based on amplitude threshold
        thresh = np.percentile(total_amp, amp_thresh_percentile)
        burst_mask = total_amp > thresh

        # Label contiguous burst regions (labels run from 1 to n_bursts)
        labeled, n_bursts = label(burst_mask)

        # Smooth and prepare average frequency
        avg_freq_smooth = gaussian_filter1d(avg_freq_list[i], sigma=sigma)

        plt.figure(figsize=(10, 4))
        # Values are stored in Hz; divide by 1000 to match the kHz axis label
        plt.plot(t, avg_freq_list[i] / 1000, label="Avg Freq", alpha=0.3)
        plt.plot(t, avg_freq_smooth / 1000, label="Smoothed", color="tab:blue")

        # Only the first change-point line of the figure carries a legend
        # label, otherwise every burst would add a duplicate entry.
        cp_label_pending = True

        for b in range(1, n_bursts + 1):
            idx = np.where(labeled == b)[0]
            if len(idx) < min_burst_len:
                continue

            # Extract the burst and fill NaN gaps by interpolation
            af_segment = avg_freq_smooth[idx]
            af_segment = pd.Series(af_segment).interpolate(limit_direction='both').values
            if not np.isfinite(af_segment).any():
                continue  # nothing to interpolate from: no usable data in this burst

            # Normalise to zero mean / unit variance (epsilon avoids division by zero)
            af_segment = (af_segment - np.nanmean(af_segment)) / (np.nanstd(af_segment) + 1e-8)
            signal = af_segment.reshape(-1, 1)

            # CPD within burst
            try:
                algo = rpt.Pelt(model="rbf").fit(signal)
                cpts = algo.predict(pen=pen)
            except Exception:
                # Best effort: skip bursts ruptures cannot segment, without
                # swallowing KeyboardInterrupt/SystemExit like a bare except.
                continue

            # Drop breakpoints at or beyond the end of the burst, then
            # enforce a minimum time spacing between accepted change points.
            t_seg = t[idx]
            cpts = [j for j in cpts if j < len(t_seg)]
            filtered = cpts[:1]
            for j in cpts[1:]:
                if t_seg[j] - t_seg[filtered[-1]] > min_gap:
                    filtered.append(j)

            # Plot
            for j in filtered:
                plt.axvline(t_seg[j], color="red", linestyle="--",
                            label="Change Point" if cp_label_pending else "")
                cp_label_pending = False

        plt.xlabel("Time [s]")
        plt.ylabel("Avg Frequency [kHz]")
        plt.title(f"Ruptures CPD (burst-wise) - Shot {shot_ids[i]}")
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
# The first argument must be the 2D (freq x time) masked STFT amplitudes, not
# the raw 1D voltage signals in ds_list: burst detection sums amplitude over
# frequency for every STFT time bin.
plot_avg_freq_cpd_bursts(seg_list, t_list, avg_freq_list, shot_ids)
../../_images/e8276622c1dfe73f40617c7ec383fd9b22aba6c188ab4810e68c1146e1f81b30.png ../../_images/15d113a7b2e62c0c23c85f9925a7ca5cc3d6639679e365e92d53f4db7007e268.png ../../_images/47de8afe2bbbcd19729c4f440e199c27e5db6dc0c49e0c3fa191c00cd15e1708.png ../../_images/5dab00cc45a01ca40b2b15724d1d246f2f2bb88db550206b3399d4de045a0546.png
# Summary of the global PELT run: `i` and `change_times` are module-level
# values from that cell, not from the burst-wise function.
n_change_points = len(change_times)
print(f"Shot {shot_ids[i]}: Detected {n_change_points} change points")
print("Change times:", change_times)
Shot 30421: Detected 72 change points
Change times: [np.float64(0.01000002000004), np.float64(0.02000004000008), np.float64(0.030000060000120003), np.float64(0.04000008000016), np.float64(0.0500001000002), np.float64(0.06000012000024), np.float64(0.07000014000027999), np.float64(0.08000016000032001), np.float64(0.09000018000036), np.float64(0.1000002000004), np.float64(0.11000022000043999), np.float64(0.12000024000048001), np.float64(0.13000026000052), np.float64(0.14000028000056), np.float64(0.1500003000006), np.float64(0.16000032000064002), np.float64(0.17000034000068), np.float64(0.18000036000072), np.float64(0.19000038000076), np.float64(0.20000040000080002), np.float64(0.21000042000084002), np.float64(0.22000044000088), np.float64(0.23000046000092), np.float64(0.24000048000096), np.float64(0.250000500001), np.float64(0.26000052000104), np.float64(0.27000054000108004), np.float64(0.28000056000112), np.float64(0.29000058000116), np.float64(0.3000006000012), np.float64(0.31000062000124), np.float64(0.32000064000128003), np.float64(0.33000066000132), np.float64(0.34000068000136), np.float64(0.3500007000014), np.float64(0.36000072000144), np.float64(0.37000074000148003), np.float64(0.38000076000152), np.float64(0.39000078000156), np.float64(0.4000008000016), np.float64(0.41000082000164), np.float64(0.42000084000168003), np.float64(0.43000086000172), np.float64(0.44000088000176), np.float64(0.4500009000018), np.float64(0.46000092000184), np.float64(0.47000094000188003), np.float64(0.48000096000192), np.float64(0.49000098000196), np.float64(0.500001000002), np.float64(0.5100010200020401), np.float64(0.52000104000208), np.float64(0.53000106000212), np.float64(0.54000108000216), np.float64(0.5500011000022), np.float64(0.56000112000224), np.float64(0.57000114000228), np.float64(0.58000116000232), np.float64(0.59000118000236), np.float64(0.6000012000024), np.float64(0.6100012200024401), np.float64(0.62000124000248), np.float64(0.63000126000252), np.float64(0.64000128000256), np.float64(0.6500013000026), 
np.float64(0.66000132000264), np.float64(0.67000134000268), np.float64(0.68000136000272), np.float64(0.69000138000276), np.float64(0.7000014000028), np.float64(0.7100014200028401), np.float64(0.72000144000288)]

Extract information about each contour#

def extract_contour_features(contour, t, f, Zxx):
    """
    Extract features from a single contour.

    Parameters:
    - contour: (N, 2) array of (row, column) = (frequency index, time index)
      coordinates; fractional values are truncated to whole bins.
    - t: 1D time axis.
    - f: 1D frequency axis (features are reported in the same unit).
    - Zxx: 2D array (freq x time); its absolute value gives the amplitudes.

    Returns:
    - dict with duration, freq_span, slope, avg_amp, max_amp, start_time,
      end_time, start_freq, end_freq and length, or None for a contour with
      fewer than two points. `slope` is NaN when all points share one time
      value, because a frequency-vs-time line fit is undefined there.
    """
    # Handle degenerate contours
    if len(contour) < 2:
        return None

    # Truncate to bin indices and clamp them onto the grid
    time_idx = np.clip(contour[:, 1].astype(int), 0, len(t) - 1)
    freq_idx = np.clip(contour[:, 0].astype(int), 0, len(f) - 1)

    times = t[time_idx]
    freqs = f[freq_idx]
    amps = np.abs(Zxx[freq_idx, time_idx])

    duration = times.max() - times.min()
    freq_span = freqs.max() - freqs.min()

    if duration > 0:
        # Leading coefficient of a first-order least-squares fit
        slope = np.polyfit(times, freqs, 1)[0]
    else:
        # A fit over a single time value is rank-deficient and would return
        # an arbitrary number; report the slope as undefined instead.
        slope = np.nan

    return {
        'duration': duration,
        'freq_span': freq_span,
        'slope': slope,
        'avg_amp': np.mean(amps),
        'max_amp': np.max(amps),
        'start_time': times.min(),
        'end_time': times.max(),
        'start_freq': freqs[0],   # frequency at the first contour vertex
        'end_freq': freqs[-1],    # frequency at the last contour vertex
        'length': len(times),
    }
# One feature dict per usable contour, accumulated across all shots
all_feature_dfs = []

for sid, masked_stft, times, freqs, spectrum in zip(shot_ids, seg_list, t_list, f_list, Zxx_list):
    # Masked-out bins are NaN, so the finite bins form the foreground of a 0/1 image
    binary_mask = np.isfinite(masked_stft).astype(float)

    # Level 0.5 sits halfway between background (0) and foreground (1)
    contours = measure.find_contours(binary_mask, level=0.5)

    for contour in contours:
        contour_features = extract_contour_features(contour, times, freqs, spectrum)
        if contour_features is None:
            continue  # the helper rejected this contour
        contour_features['shot_id'] = sid  # tag the row with its shot
        all_feature_dfs.append(contour_features)

# One row per contour
df_all = pd.DataFrame(all_feature_dfs)
df_all.head()
duration freq_span slope avg_amp max_amp start_time end_time start_freq end_freq length shot_id
0 0.006 0.0000 0.000000 0.027723 0.033559 0.180000 0.186000 0.0000 0.0000 5 23447
1 0.062 9249.9815 -90778.241244 0.014232 0.068434 0.404001 0.466001 0.0000 0.0000 155 23447
2 0.006 4999.9900 -157204.966494 0.018831 0.075571 0.342001 0.348001 18999.9620 18999.9620 47 23447
3 0.062 5749.9885 32883.460780 0.016402 0.066716 0.352001 0.414001 20499.9590 20499.9590 133 23447
4 0.004 999.9980 -111280.042683 0.026360 0.046078 0.434001 0.438001 15749.9685 15749.9685 13 23447
# Drop rows with any missing feature
df_clean = df_all.dropna()

###### More filtering needed. A lot of garbage contour still left #######
keep_rows = (
    (df_clean['duration'] > 0.004)
    & (df_clean['avg_amp'] > 1e-4)
    & (df_clean['length'] >= 5)
    & (df_clean['freq_span'] > 0.1)
)
# Explicit copy: columns are added to df_clean later, and assigning into a
# filtered slice would raise pandas' SettingWithCopyWarning.
df_clean = df_clean[keep_rows].copy()

# select all rows with shot_id 23447
df_clean[df_clean['shot_id'] == 23447].head()
duration freq_span slope avg_amp max_amp start_time end_time start_freq end_freq length shot_id
1 0.062 9249.9815 -90778.241244 0.014232 0.068434 0.404001 0.466001 0.0000 0.0000 155 23447
2 0.006 4999.9900 -157204.966494 0.018831 0.075571 0.342001 0.348001 18999.9620 18999.9620 47 23447
3 0.062 5749.9885 32883.460780 0.016402 0.066716 0.352001 0.414001 20499.9590 20499.9590 133 23447
4 0.004 999.9980 -111280.042683 0.026360 0.046078 0.434001 0.438001 15749.9685 15749.9685 13 23447
5 0.018 9499.9810 -289641.192242 0.021990 0.063202 0.344001 0.362001 29999.9400 29999.9400 97 23447
# Number of contours that remain in df_clean
len(df_clean)
197
# Summary statistics for every column of df_clean
df_clean.describe()
duration freq_span slope avg_amp max_amp start_time end_time start_freq end_freq length shot_id
count 197.000000 197.000000 197.000000 197.000000 197.000000 197.000000 197.000000 197.000000 197.000000 197.000000 197.000000
mean 0.009492 3944.154548 -40219.120192 0.022603 0.053714 0.363706 0.373199 46320.973348 46320.973348 48.837563 29168.340102
std 0.012287 4251.873305 109788.621468 0.005910 0.014438 0.102645 0.102751 34761.472759 34761.472759 75.147501 2389.785404
min 0.004000 249.999500 -475443.906896 0.009941 0.025553 0.060000 0.068000 0.000000 0.000000 6.000000 23447.000000
25% 0.004000 1249.997500 -69270.556250 0.018105 0.041553 0.296001 0.302001 15249.969500 15249.969500 15.000000 30005.000000
50% 0.006000 2249.995500 -13449.313291 0.022782 0.053496 0.332001 0.340001 41499.917000 41499.917000 27.000000 30021.000000
75% 0.010000 5249.989500 12578.566038 0.026939 0.063753 0.420001 0.436001 78499.843000 78499.843000 53.000000 30421.000000
max 0.084000 27749.944500 181074.502167 0.040106 0.087295 0.658001 0.668001 116999.766000 116999.766000 707.000000 30421.000000

Simple clustering#

from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler

# Contour descriptors used as clustering inputs
feature_columns = ['duration', 'freq_span', 'slope', 'avg_amp', 'max_amp']
features = df_clean[feature_columns]

# Standardise each column (zero mean, unit variance) before clustering
scaler = StandardScaler()
X_scaled = scaler.fit_transform(features)

# Three clusters, with a fixed seed for reproducibility
kmeans = KMeans(n_clusters=3, random_state=42)
cluster_labels = kmeans.fit_predict(X_scaled)
df_clean['cluster'] = cluster_labels
def plot_clustered_contours_by_shot(df, t_range=(0.1, 0.46), f_range=(0, 50), cluster_names=None):
    """
    Plots clustered mode segments (time vs frequency) separately for each shot_id.

    Every contour is drawn as a horizontal line at its start frequency,
    spanning its start time to its end time, coloured by cluster.

    Parameters:
    - df: DataFrame with extracted features ('shot_id', 'start_freq',
      'start_time', 'end_time') and a 'cluster' column.
    - t_range: Tuple (tmin, tmax) for x-axis.
    - f_range: Tuple (fmin, fmax) for y-axis, in kHz.
    - cluster_names: Optional dict mapping cluster numbers to names.
    """
    cluster_colors = ['red', 'green', 'blue', 'purple', 'orange']

    for shot in df['shot_id'].unique():
        plt.figure(figsize=(12, 6))
        shot_data = df[df['shot_id'] == shot]

        for cluster_id in sorted(shot_data['cluster'].dropna().unique()):
            cluster_key = int(cluster_id)
            cluster_df = shot_data[shot_data['cluster'] == cluster_id]

            # Cycle through the palette so more clusters than colours is not an error
            color = cluster_colors[cluster_key % len(cluster_colors)]

            # Resolve the legend text up front instead of parsing labels afterwards
            legend_label = f'Cluster {cluster_key}'
            if cluster_names:
                legend_label = cluster_names.get(cluster_key, legend_label)

            # One vectorised call per cluster gives exactly one legend entry
            plt.hlines(
                y=cluster_df['start_freq'] / 1000,  # stored in Hz, plotted in kHz
                xmin=cluster_df['start_time'],
                xmax=cluster_df['end_time'],
                colors=color,
                linewidth=2,
                label=legend_label,
            )

        plt.title(f"Clustered Mode Segments – Shot {shot}")
        plt.xlabel("Time [sec]")
        plt.ylabel("Frequency [kHz]")
        plt.xlim(*t_range)
        plt.ylim(*f_range)
        handles, labels = plt.gca().get_legend_handles_labels()
        plt.legend(handles, labels)
        plt.grid(True)
        plt.tight_layout()
        plt.show()
# One figure per shot, using the function's default time and frequency limits
plot_clustered_contours_by_shot(df_clean)
../../_images/5352f3c8031e50957b5a81af53e24ef25aa63bccda56ab7100a62beb77bb8f99.png ../../_images/6afc728e4886d17cab8a0b43dfc9fa5c2acc5804f481891407f2ed6b78a5726f.png ../../_images/b923c622c06471181bbfd00f0f39c9ab8d6c2bc6d88fde1bc9179a79cfaeb288.png ../../_images/fc037629ff101c74a0b810207aea903742683797f0ce8760c113948e9ad4d1d3.png