Ridge detection#

import zarr
import zarr.storage
import fsspec
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from scipy.signal import stft
from scipy.ndimage import uniform_filter1d
# Shot IDs to analyse (add more as needed).
shot_ids = [23447, 30005, 30021, 30421]

# S3 endpoint, accessed anonymously through fsspec's 'simplecache' protocol.
endpoint = "https://s3.echo.stfc.ac.uk"
fs = fsspec.filesystem(
    protocol='simplecache',
    target_protocol="s3",
    target_options=dict(anon=True, endpoint_url=endpoint)
)

store_list = []
zgroup_list = []
loaded_shot_ids = []

# Open the Zarr group of each shot. A store is only recorded once its group
# has actually been opened, so store_list, zgroup_list and the shot IDs stay
# index-aligned even when some shots cannot be loaded.
for shot_id in shot_ids:
    url = f"s3://mast/level2/shots/{shot_id}.zarr"
    store = zarr.storage.FSStore(fs=fs, url=url)
    try:
        zgroup = zarr.open(store, mode='r')
    except Exception as e:
        # Best-effort loading: report the failure and skip this shot.
        print(f"Failed to load shot ID {shot_id}: {e}")
        continue
    store_list.append(store)
    zgroup_list.append(zgroup)
    loaded_shot_ids.append(shot_id)
    print(f"Loaded shot ID {shot_id}")

# Later cells pair shot_ids[i] with the i-th loaded dataset, so drop the IDs
# of shots that failed to load (a no-op when every shot loaded).
shot_ids = loaded_shot_ids
Loaded shot ID 23447
Loaded shot ID 30005
Loaded shot ID 30021
Loaded shot ID 30421
# Open the "magnetics" group of every shot's store as an xarray Dataset,
# in the same order as store_list.
mirnov = [xr.open_zarr(store, group="magnetics") for store in store_list]
mirnov[0]
<xarray.Dataset> Size: 32MB
Dimensions:                                   (b_field_pol_probe_cc_channel: 5,
                                               time_mirnov: 261200,
                                               b_field_pol_probe_ccbv_channel: 40,
                                               time: 2612,
                                               b_field_pol_probe_obr_channel: 18,
                                               b_field_pol_probe_obv_channel: 18,
                                               b_field_pol_probe_omv_channel: 3,
                                               b_field_tor_probe_cc_channel: 3,
                                               b_field_tor_probe_saddle_field_channel: 12,
                                               time_saddle: 26120,
                                               b_field_tor_probe_saddle_voltage_channel: 12,
                                               flux_loop_channel: 15)
Coordinates:
  * b_field_pol_probe_cc_channel              (b_field_pol_probe_cc_channel) <U13 260B ...
  * b_field_pol_probe_ccbv_channel            (b_field_pol_probe_ccbv_channel) <U10 2kB ...
  * b_field_pol_probe_obr_channel             (b_field_pol_probe_obr_channel) <U9 648B ...
  * b_field_pol_probe_obv_channel             (b_field_pol_probe_obv_channel) <U9 648B ...
  * b_field_pol_probe_omv_channel             (b_field_pol_probe_omv_channel) <U11 132B ...
  * b_field_tor_probe_cc_channel              (b_field_tor_probe_cc_channel) <U13 156B ...
  * b_field_tor_probe_saddle_field_channel    (b_field_tor_probe_saddle_field_channel) <U11 528B ...
  * b_field_tor_probe_saddle_voltage_channel  (b_field_tor_probe_saddle_voltage_channel) <U15 720B ...
  * flux_loop_channel                         (flux_loop_channel) <U12 720B '...
  * time                                      (time) float64 21kB -0.099 ... ...
  * time_mirnov                               (time_mirnov) float64 2MB -0.09...
  * time_saddle                               (time_saddle) float64 209kB -0....
Data variables:
    b_field_pol_probe_cc_field                (b_field_pol_probe_cc_channel, time_mirnov) float64 10MB ...
    b_field_pol_probe_ccbv_field              (b_field_pol_probe_ccbv_channel, time) float64 836kB ...
    b_field_pol_probe_obr_field               (b_field_pol_probe_obr_channel, time) float64 376kB ...
    b_field_pol_probe_obv_field               (b_field_pol_probe_obv_channel, time) float64 376kB ...
    b_field_pol_probe_omv_voltage             (b_field_pol_probe_omv_channel, time_mirnov) float64 6MB ...
    b_field_tor_probe_cc_field                (b_field_tor_probe_cc_channel, time_mirnov) float64 6MB ...
    b_field_tor_probe_saddle_field            (b_field_tor_probe_saddle_field_channel, time_saddle) float64 3MB ...
    b_field_tor_probe_saddle_voltage          (b_field_tor_probe_saddle_voltage_channel, time_saddle) float64 3MB ...
    flux_loop_flux                            (flux_loop_channel, time) float64 313kB ...
    ip                                        (time) float64 21kB ...
Attributes:
    description:  
    imas:         magnetics
    label:        Plasma Current
    name:         magnetics
    uda_name:     AMC_PLASMA CURRENT
    units:        A
# Extract, for every shot, the DataArray at channel index 1 of
# 'b_field_pol_probe_omv_voltage'.
ds_list = [
    m['b_field_pol_probe_omv_voltage'].isel(b_field_pol_probe_omv_channel=1)
    for m in mirnov
]

# Plot each shot in its own figure, labelled with the real shot ID rather
# than its position in the list.
for i, (shot_id, ds) in enumerate(zip(shot_ids, ds_list)):
    plt.figure(i)
    ds.plot(label=f"Shot {shot_id}")
../../_images/8fb1e60b4232a66621a4da5771456af08b0cff13735ac3c64081a39b291c3b9b.png ../../_images/c8af5f987476b86614594f1f669b77bf75a27553c30ef1dd73f3d24eb4d919ad.png ../../_images/1fbfa89779911d8c6396a0fb1669ced0f60b321cca67af666cc58483fa3cff48.png ../../_images/6a9b1fbf7ed2557ab258c9aabbfd730c7179deab62f6d1fec600a7a4f62d6cd6.png

Short-Time Fourier Transform (STFT)#

def plot_stft_spectrogram(
    ds,
    shot_id=None,
    nperseg=2000,
    nfft=2000,
    tmin=0.1,
    tmax=0.46,
    fmax_kHz=50,
    cmap='jet',
):
    """
    Plot STFT spectrogram for a given xarray DataArray `ds`.

    The sample rate is derived from the spacing of the first two
    `time_mirnov` samples and passed to `scipy.signal.stft` as a float, so
    floating-point error in the time step is not truncated into a wrong
    integer rate.

    Parameters:
    - ds: xarray.DataArray with a 'time_mirnov' coordinate.
    - shot_id: Optional shot ID for labeling.
    - nperseg: Number of points per STFT segment.
    - nfft: Number of FFT points.
    - tmin, tmax: Time range to display (seconds).
    - fmax_kHz: Max frequency to display (kHz).
    - cmap: Colormap name.

    Returns:
    - None. The spectrogram is drawn on a newly created figure.
    """
    dt = float(ds.time_mirnov[1] - ds.time_mirnov[0])
    sample_rate = 1.0 / dt
    f, t, Zxx = stft(ds.values, fs=sample_rate, nperseg=nperseg, nfft=nfft)
    # NOTE(review): `t` as returned by stft is presumably measured from the
    # first sample, so the x-axis is offset from shot time whenever
    # `time_mirnov` does not start at 0 - confirm that tmin/tmax are meant
    # on that axis.

    fig, ax = plt.subplots(figsize=(15, 5))
    mesh = ax.pcolormesh(
        t, f / 1000, np.abs(Zxx),
        shading='nearest',
        cmap=plt.get_cmap(cmap, 15),  # colormap resampled to 15 levels
        norm=LogNorm(vmin=1e-5)
    )
    ax.set_ylim(0, fmax_kHz)
    ax.set_xlim(tmin, tmax)
    ax.set_ylabel('Frequency [kHz]')
    ax.set_xlabel('Time [sec]')

    title = "STFT Spectrogram"
    if shot_id is not None:
        title += f" - Shot {shot_id}"
    ax.set_title(title)

    plt.colorbar(mesh, ax=ax, label='Amplitude')
    plt.tight_layout()
# Draw one spectrogram per shot. A plain loop is used because the calls are
# made for their plotting side effect only (they return None).
for ds, shot_id in zip(ds_list, shot_ids):
    plot_stft_spectrogram(ds, shot_id)
[None, None, None, None]
../../_images/c17b52e0d173da12c6d17afee38b4e094cde8d4da00557feb6fbaf291ea802cb.png ../../_images/ed762b82c3fa9f2d731807eee6d583e86e93019ec85f8f8ad8a2a6d567c7c262.png ../../_images/30cc300e3f10893e6de096d329a639846fc5174aab723321a9627c2fe135d9f8.png ../../_images/5100becfcaefdc610237cdadd64ae136681f621bc0aa8365709b13a94614f44b.png
def plot_stft_histogram(
    ds,
    shot_id=None,
    nperseg=2000,
    nfft=2000,
    bins=100,
):
    """
    Plot a histogram of the absolute STFT amplitude values.

    The sample rate is derived from the spacing of the first two
    `time_mirnov` samples and passed to `scipy.signal.stft` as a float, so
    floating-point error in the time step is not truncated into a wrong
    integer rate.

    Parameters:
    - ds: xarray.DataArray with a 'time_mirnov' coordinate.
    - shot_id: Optional shot ID for labeling.
    - nperseg: Number of points per STFT segment.
    - nfft: Number of FFT points.
    - bins: Number of histogram bins.

    Returns:
    - (f, t, Zxx): the frequency array, segment-time array and complex STFT
      matrix exactly as returned by `scipy.signal.stft`.
    """
    dt = float(ds.time_mirnov[1] - ds.time_mirnov[0])
    sample_rate = 1.0 / dt
    f, t, Zxx = stft(ds.values, fs=sample_rate, nperseg=nperseg, nfft=nfft)

    plt.figure(figsize=(8, 4))
    # log=True puts the counts on a logarithmic axis.
    plt.hist(np.abs(Zxx).ravel(), bins=bins, log=True)
    plt.xlabel('Amplitude')
    plt.ylabel('Count (log scale)')

    title = 'Histogram of Spectrogram Amplitudes'
    if shot_id is not None:
        title += f" - Shot {shot_id}"
    plt.title(title)

    plt.grid(True)
    plt.tight_layout()
    return f, t, Zxx
# Compute the STFT of every shot once (plotting its amplitude histogram as a
# side effect) and keep the results for the ridge-detection cells below.
f_list, t_list, Zxx_list = [], [], []
for ds, shot_id in zip(ds_list, shot_ids):
    f, t, Zxx = plot_stft_histogram(ds, shot_id)
    f_list.append(f)
    t_list.append(t)
    Zxx_list.append(Zxx)
../../_images/e90e279a3cc4cc0d7fad806434ba032342e7892b60cf358ea5bb5a45fbf3bade.png ../../_images/1e8e995f70704eddd9b00c006dace17736876b6802084bede341531ba4ea8097.png ../../_images/1f906a19e864a23391bb6399072e82484297d67c5510ebee8132f2aea812af50.png ../../_images/b2d04bae46aeeb9d41047021767121cd8c976e51282d1e1bccb13e1a1202be36.png

Ridge detection#

Track the dominant frequencies of Zxx over time#

Find the most prominent modes (like chirping MHD modes) in the plasma.

  • We compute the magnitude of Zxx; then, for each time slice (Zxx[:, t]), we find the index of the maximum magnitude and map it back to a frequency using f. Each column of Zxx corresponds to one time slice.

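  • max_ridges sets how many ridges to extract. For example, max_ridges=2 picks the two strongest frequencies in each time bin: after a peak is selected, that bin and its neighbours within min_bin_distance are masked out, and the next ridge is taken from the strongest remaining bin.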

def _smooth_ridge(ridge_freq, window):
    """Moving-average smooth a ridge track, ignoring its NaN gaps."""
    valid = ~np.isnan(ridge_freq)
    if not valid.any():
        return ridge_freq

    # Windowed mean of the zero-filled samples divided by the windowed mean
    # of the validity mask gives the mean over the valid samples only, so a
    # gap no longer drags neighbouring values toward 0 kHz.
    filled = np.where(valid, ridge_freq, 0.0)
    num = uniform_filter1d(filled, size=window)
    den = uniform_filter1d(valid.astype(float), size=window)

    smoothed = np.full(ridge_freq.shape, np.nan)
    np.divide(num, den, out=smoothed, where=valid)
    return smoothed


def plot_stft_ridge(
    t,
    f,
    Zxx,
    shot_id=None,
    threshold=1e-3,
    max_ridges=1,
    fmax_kHz=50,
    smooth=True,
    smooth_window=3,
    return_ridges=False,
    min_bin_distance=2
):
    """
    Extract and plot top-N ridges in STFT with better control.
    Prevents overlap and handles low-power regions.

    For every time bin the strongest remaining frequency bin at or below
    `fmax_kHz` is taken as the ridge point; that bin and its neighbours
    within `min_bin_distance` are then excluded from later ridges at the
    same time bin. Time bins whose peak is below `threshold` (or zero) get
    NaN.

    Parameters:
    - t: Segment times, one per column of Zxx (used as the x-axis).
    - f: Frequencies in Hz, one per row of Zxx.
    - Zxx: STFT matrix indexed as [frequency, time].
    - shot_id: Optional shot ID for labeling.
    - threshold: Minimum magnitude for a bin to count as a ridge point.
    - max_ridges: Number of ridges to extract.
    - fmax_kHz: Highest frequency considered (kHz).
    - smooth: Whether to smooth each ridge with a moving average.
    - smooth_window: Moving-average window length in time bins.
    - return_ridges: Whether to return the extracted ridges.
    - min_bin_distance: Number of frequency bins on each side of a selected
      peak that later ridges may not reuse.

    Returns:
    - A list of `max_ridges` arrays of ridge frequencies in kHz (NaN where
      no ridge was found) if `return_ridges` is true, otherwise None.
    """
    fmax = fmax_kHz * 1000
    in_band = f <= fmax
    f = f[in_band]

    # np.abs allocates a new array, so thresholding below never touches the
    # caller's Zxx.
    magnitude = np.abs(Zxx[in_band, :])
    freq_bins, time_bins = magnitude.shape

    # Set low amplitudes to zero
    magnitude[magnitude < threshold] = 0

    # Bins already claimed by an earlier ridge (plus their neighbours).
    used_mask = np.zeros_like(magnitude, dtype=bool)
    ridge_list = []

    plt.figure(figsize=(10, 4))

    for r in range(max_ridges):
        ridge_freq = np.full(time_bins, np.nan)

        # With no frequency bins in band there is nothing to search.
        for t_idx in (range(time_bins) if freq_bins else ()):
            # Amplitudes for this time bin with already-used bins zeroed.
            col = np.where(used_mask[:, t_idx], 0, magnitude[:, t_idx])
            i = int(np.argmax(col))
            peak = col[i]

            # A zero peak means every bin is empty or already used; never
            # report it as a ridge, even when threshold <= 0.
            if not (peak > 0 and peak >= threshold):
                continue

            ridge_freq[t_idx] = f[i] / 1000  # to kHz

            # Mark this bin and neighboring bins as used
            i_start = max(0, i - min_bin_distance)
            i_end = min(freq_bins, i + min_bin_distance + 1)
            used_mask[i_start:i_end, t_idx] = True

        if smooth:
            ridge_freq = _smooth_ridge(ridge_freq, smooth_window)

        plt.plot(t, ridge_freq, label=f'Ridge {r+1}', lw=2)
        ridge_list.append(ridge_freq)

    plt.xlabel("Time [sec]")
    plt.ylabel("Frequency [kHz]")
    title = "STFT Ridge Detection"
    if shot_id is not None:
        title += f" – Shot {shot_id}"
    plt.title(title)
    plt.grid(True)
    plt.legend()
    plt.tight_layout()

    if return_ridges:
        return ridge_list
    return None
# Strongest ridge of every shot. A plain loop is used because the calls are
# made for their plotting side effect only.
for t, f, Zxx, shot_id in zip(t_list, f_list, Zxx_list, shot_ids):
    plot_stft_ridge(
        t,
        f,
        Zxx,
        shot_id=shot_id,
        threshold=1e-3,
        max_ridges=1,
        fmax_kHz=50,
        smooth=True
    )
[None, None, None, None]
../../_images/b77ae9c7e68f1d5a2885bb0469c6b5b3990c369cbd2fe1b7f26c77523f7047a9.png ../../_images/0487ef9fb88df0244b2be89d786231a0d2fb80f5c5606d736d2d31528e085db3.png ../../_images/5c771d5c7c525d9b01e1f063a4a7e8b9d498f87ecefe3ef38c57d789d464eff2.png ../../_images/86eaa809360a079bcb4c1c9365a8371b3504f17067f956f527510c91abdf737a.png
# Two strongest ridges of every shot. A plain loop is used because the calls
# are made for their plotting side effect only.
for t, f, Zxx, shot_id in zip(t_list, f_list, Zxx_list, shot_ids):
    plot_stft_ridge(
        t,
        f,
        Zxx,
        shot_id=shot_id,
        threshold=1e-3,
        max_ridges=2,
        fmax_kHz=50,
        smooth=True
    )
[None, None, None, None]
../../_images/261a6dd872b58312a0b6a25f9b53c51cc94f397f27705fc4e1144f6dc81e0b38.png ../../_images/9a9ad1cf1ab741061667f7cc558a89e2a650bf7cb2fed95bf272bab82e36642f.png ../../_images/e62e63c3bc2ef2969d9c30537a2254b89f96bab8bfd8bbac9952e709991b162b.png ../../_images/69822d5928f49e7c53e082f3439fcb408a51c689d56cb71eac8da20c28dd4a94.png
def plot_spectrogram_with_ridges(t, f, Zxx, ridge_freqs_kHz, shot_id=None, vmin=1e-5, fmax_kHz=50,
                                 tmin=0.1, tmax=None):
    """
    Plot the STFT spectrogram with overlayed ridge tracks.

    Parameters:
    - t: Segment times, one per column of Zxx (x-axis).
    - f: Frequencies in Hz, one per row of Zxx.
    - Zxx: STFT matrix indexed as [frequency, time].
    - ridge_freqs_kHz: Iterable of ridge tracks in kHz, each the same length
      as t. None is treated as "no ridges", since that is what
      plot_stft_ridge returns when called without return_ridges=True.
    - shot_id: Optional shot ID for labeling.
    - vmin: Lower bound of the logarithmic color scale.
    - fmax_kHz: Max frequency to display (kHz).
    - tmin: Start of the displayed time range (seconds).
    - tmax: End of the displayed time range (seconds); defaults to t[-1].

    Returns:
    - None. The plot is drawn on a newly created figure.
    """
    if ridge_freqs_kHz is None:
        ridge_freqs_kHz = []

    fig, ax = plt.subplots(figsize=(12, 5))
    f_kHz = f / 1000

    mesh = ax.pcolormesh(t, f_kHz, np.abs(Zxx), shading='nearest',
                         norm=LogNorm(vmin=vmin), cmap='jet')

    # Custom color cycle, reused from the start if there are more ridges
    # than colors.
    ridge_colors = ['black', 'c', 'magenta', 'lime', 'orange']

    has_ridges = False
    for i, ridge in enumerate(ridge_freqs_kHz):
        color = ridge_colors[i % len(ridge_colors)]
        ax.plot(t, ridge, label=f'Ridge {i+1}', lw=2, color=color)
        has_ridges = True

    ax.set_ylim(0, fmax_kHz)
    ax.set_xlim(tmin, t[-1] if tmax is None else tmax)
    ax.set_xlabel('Time [sec]')
    ax.set_ylabel('Frequency [kHz]')
    title = "STFT Spectrogram with Ridge Overlay"
    if shot_id is not None:
        title += f" – Shot {shot_id}"
    ax.set_title(title)
    plt.colorbar(mesh, ax=ax, label="Amplitude")
    if has_ridges:
        # A fixed location avoids the slow loc="best" placement search that
        # matplotlib warns about on large plots.
        ax.legend(loc='upper right')
    ax.grid(True)
    plt.tight_layout()
# Detect the two strongest ridges of the first shot and keep the tracks.
ridges = plot_stft_ridge(
    t=t_list[0], f=f_list[0], Zxx=Zxx_list[0],
    shot_id=shot_ids[0],
    threshold=1e-3, max_ridges=2, fmax_kHz=50,
    smooth=True, return_ridges=True,
)

# Draw those tracks on top of the same shot's spectrogram.
plot_spectrogram_with_ridges(
    t_list[0], f_list[0], Zxx_list[0], ridges,
    shot_id=shot_ids[0],
)
/tmp/ipykernel_1087196/2863097040.py:32: UserWarning: Creating legend with loc="best" can be slow with large amounts of data.
  plt.tight_layout()
../../_images/261a6dd872b58312a0b6a25f9b53c51cc94f397f27705fc4e1144f6dc81e0b38.png ../../_images/788339521bd66dfc75e40781e6a79281f0a6cfe05364955a4a2aa4d051d3a0b6.png
# Detect the three strongest ridges of the first shot and keep the tracks.
ridges = plot_stft_ridge(
    t=t_list[0], f=f_list[0], Zxx=Zxx_list[0],
    shot_id=shot_ids[0],
    threshold=1e-3, max_ridges=3, fmax_kHz=50,
    smooth=True, return_ridges=True,
)

# Draw those tracks on top of the same shot's spectrogram.
plot_spectrogram_with_ridges(
    t_list[0], f_list[0], Zxx_list[0], ridges,
    shot_id=shot_ids[0],
)
../../_images/d3f2158c86f4ebd55613ffa1717932b2e0fd83c4be6729e7c0e1d5c99c7417ab.png ../../_images/64c2ee46c2564c137e896a9f0a62526450bee61faae26b9c27a4e2124ebdb8b5.png