PCA with K-means

Download data from CSD3 in case the S3 bucket is offline

#!/bin/bash
#
# Fallback download: copy the level-2 Zarr store of each shot from the CSD3
# host with rsync, for use when the S3 bucket cannot be reached.
# NOTE(review): "CSD3" is presumably an SSH host alias (e.g. in ~/.ssh/config);
# confirm it is configured before running.

# Define shot IDs
shot_ids=(23447 30005 30021 30421)

# Remote info
remote_alias="CSD3"
remote_dir="/rds/project/rds-mOlK9qn0PlQ/fairmast/upload-tmp/level2"

# Local destination. The notebook's loader opens "shots/<shot_id>.zarr", so the
# stores are placed there instead of the current directory.
dest_dir="shots"
if ! mkdir -p -- "${dest_dir}"; then
    echo "Cannot create destination directory ${dest_dir}" >&2
    exit 1
fi

# Loop through each shot_id and use rsync to download. A failure is recorded
# and the loop continues, so one bad shot does not block the others.
failed=()
for shot_id in "${shot_ids[@]}"; do
    echo "Downloading ${shot_id}.zarr from ${remote_alias}..."
    if ! rsync -avz "${remote_alias}:${remote_dir}/${shot_id}.zarr" "${dest_dir}/"; then
        echo "Failed to download ${shot_id}.zarr" >&2
        failed+=("${shot_id}")
    fi
done

# Only claim success when every transfer actually succeeded.
if (( ${#failed[@]} > 0 )); then
    echo "Downloads failed for: ${failed[*]}" >&2
    exit 1
fi

echo "All downloads complete."
# Denoising
import zarr
import zarr.storage
import fsspec
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib import cm
from scipy.signal import stft
from scipy.signal import find_peaks
from collections import defaultdict
from scipy.ndimage import median_filter, gaussian_filter
from skimage import measure
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from pathlib import Path
from scipy.signal import savgol_filter
# Shots analysed in this notebook (extend the list as needed)
shot_ids = [23447, 30005, 30021, 30421]

# Anonymous S3 filesystem behind fsspec's "simplecache" layer, pointed at the
# endpoint below.
endpoint = "https://s3.echo.stfc.ac.uk"
fs = fsspec.filesystem(
    protocol='simplecache',
    target_protocol="s3",
    target_options={"anon": True, "endpoint_url": endpoint},
)
store_list = []
zgroup_list = []
loaded_shot_ids = []

# Open the local Zarr store of every requested shot.
# NOTE(review): zarr.DirectoryStore looks like the zarr-python 2.x store API —
# confirm the pinned zarr version still provides it.
for shot_id in shot_ids:
    path = Path(f"shots/{shot_id}.zarr")

    if not path.exists():
        print(f"Local path {path} not found.")
        continue

    store = zarr.DirectoryStore(str(path))

    try:
        zgroup = zarr.open(store, mode='r')
    except Exception as e:
        print(f"Failed to load shot ID {shot_id}: {e}")
        continue

    # Register the store only after it opened successfully, so the three
    # lists always have one entry per usable shot.
    store_list.append(store)
    zgroup_list.append(zgroup)
    loaded_shot_ids.append(shot_id)
    print(f"Loaded shot ID {shot_id}")
    # Example: print available array keys
    # print(list(zgroup.array_keys()))

# Later cells pair ds_list[i] with shot_ids[i]. Restrict shot_ids to the shots
# that were actually loaded so a missing or unreadable shot cannot shift the
# labels of the ones that follow it.
shot_ids = loaded_shot_ids
Loaded shot ID 23447
Loaded shot ID 30005
Loaded shot ID 30021
Loaded shot ID 30421
# DOWNLOAD FROM S3 BUCKET
# store_list = []
# zgroup_list = []

# # Loop through each shot ID
# for shot_id in shot_ids:
#     url = f"s3://mast/level2/shots/{shot_id}.zarr"
    
#     store = zarr.storage.FSStore(fs=fs, url=url)
#     store_list.append(store)
#     # open or download the Zarr group
#     try:
#         zgroup_list.append(zarr.open(store, mode='r'))
#         print(f"Loaded shot ID {shot_id}")
#         # Do something with zgroup here, like listing arrays:
#         # print(list(zgroup.array_keys()))
#     except Exception as e:
#         print(f"Failed to load shot ID {shot_id}: {e}")
# For every opened store, read the "magnetics" group and keep the voltage
# trace at channel index 1 of the 'b_field_pol_probe_omv_voltage' variable.
mirnov = []
ds_list = []
for store in store_list:
    magnetics = xr.open_zarr(store, group="magnetics")
    mirnov.append(magnetics)
    ds_list.append(
        magnetics['b_field_pol_probe_omv_voltage'].isel(b_field_pol_probe_omv_channel=1)
    )

STFT (short-time Fourier transform)

def plot_stft_spectrogram(
    ds,
    shot_id=None,
    nperseg=2000,
    nfft=2000,
    tmin=0.1,
    tmax=0.46,
    fmax_kHz=50,
    cmap='jet',
):
    """
    Plot the STFT magnitude spectrogram of a signal on a log colour scale.

    Parameters:
    - ds: xarray.DataArray with a 'time_mirnov' coordinate. The sample rate
      is taken from the spacing of its first two time samples, so uniform
      sampling is assumed.
    - shot_id: Optional shot ID appended to the plot title.
    - nperseg: Number of points per STFT segment.
    - nfft: Number of FFT points.
    - tmin, tmax: Time range to display (seconds). Currently unused: the
      x-limit call is disabled, so the full time axis is shown.
    - fmax_kHz: Max frequency to display (kHz).
    - cmap: Colormap name (resampled to 15 discrete levels).

    Returns:
    - None. The figure is created through pyplot as a side effect.
    """
    # Pass the sample rate as a float. Truncating it with int() turns e.g.
    # 499999.99 Hz into 499999 Hz and skews both the frequency and time axes.
    sample_rate = 1.0 / float(ds.time_mirnov[1] - ds.time_mirnov[0])
    f, t, Zxx = stft(ds.values, fs=sample_rate, nperseg=nperseg, nfft=nfft)
    # NOTE(review): `t` holds the STFT segment times, which appear to be
    # measured from the first sample rather than from ds.time_mirnov[0] —
    # confirm whether an offset is needed.

    fig, ax = plt.subplots(figsize=(15, 5))
    cax = ax.pcolormesh(
        t, f / 1000, np.abs(Zxx),
        shading='nearest',
        cmap=plt.get_cmap(cmap, 15),
        norm=LogNorm(vmin=1e-5)
    )
    ax.set_ylim(0, fmax_kHz)
    #ax.set_xlim(tmin, tmax)
    ax.set_ylabel('Frequency [kHz]')
    ax.set_xlabel('Time [sec]')
    title = "STFT Spectrogram"
    if shot_id is not None:
        title += f" - Shot {shot_id}"
    ax.set_title(title)
    plt.colorbar(cax, ax=ax, label='Amplitude')
    plt.tight_layout()
# Draw one spectrogram per loaded shot. A plain loop is used because the calls
# are made only for their plotting side effect; a list comprehension would
# build (and display) a throwaway list of None values.
for ds, shot_id in zip(ds_list, shot_ids):
    plot_stft_spectrogram(ds, shot_id)
[None, None, None, None]
../../_images/5aa9f2f2dbeb3ea7a55b628f64814e7894d2d06a9fc6088a3f079383d969826e.png ../../_images/26fc2e5ca74553a500c823e7f4afab7701250dfc7bc4c8948ee67fca0255aa63.png ../../_images/a84df2cd8ecff086ba7777b0664e428982f7f985ba36f76f1be3ffff5dae8420.png ../../_images/22929c875837679fa57c3c94e81ff238709d9766f8d510e96898bfe096ff1418.png

Thresholding based on percentage

def plot_amplitude_masking(ds, shot_id=None, nperseg=2000, nfft=2000,
                           sigma=1.0, apply_gaussian=False, apply_mask=True,
                           mask_percentile=60, use_percentage=True,
                           tmin=0.1, tmax=0.46, fmax_kHz=50, cmap='jet', apply_savgol_filter = False):
    """
    Compute an STFT, optionally smooth and amplitude-mask its magnitude, and
    plot the result on a log colour scale.

    Parameters:
    - ds: xarray.DataArray with 'time_mirnov'. The sample rate is taken from
      the spacing of its first two time samples.
    - shot_id: Optional shot ID used in the title and messages.
    - nperseg, nfft: STFT segment length and FFT length.
    - sigma: Standard deviation of the Gaussian blur (used when
      apply_gaussian=True).
    - apply_gaussian: Whether to apply Gaussian blur to the magnitude.
    - apply_mask: Whether to apply masking.
    - mask_percentile: If use_percentage=True, keep (roughly) the top X% of
                       the finite values. Else, mask values below the Xth
                       percentile of the finite values.
    - use_percentage: Use percentage thresholding instead of percentile.
    - tmin, tmax: Currently unused (the x-limit call is disabled).
    - fmax_kHz: Max frequency to display (kHz).
    - cmap: Colormap name (resampled to 15 discrete levels).
    - apply_savgol_filter: Whether to apply a Savitzky-Golay filter
      (window 7, order 2) along the last (time) axis of the magnitude.

    Returns:
    - (t, f, segmented_stft, Zxx): segment times, frequencies in Hz, the
      masked magnitude (NaN where masked) and the complex STFT. The same
      tuple is returned when every value is masked; only the plot is skipped.
    """

    # Compute STFT. The sample rate stays a float: truncating it with int()
    # skews both the frequency and the time axis.
    sample_rate = 1.0 / float(ds.time_mirnov[1] - ds.time_mirnov[0])
    f, t, Zxx = stft(ds.values, fs=sample_rate, nperseg=nperseg, nfft=nfft)
    magnitude = np.abs(Zxx)  # Take magnitude of complex STFT

    # Optional Gaussian blur (for smoothing the spectrogram)
    if apply_gaussian:
        magnitude = gaussian_filter(magnitude, sigma=sigma)

    # Optional Savitzky-Golay smoothing; window_length must be odd and must
    # not exceed the number of time frames.
    if apply_savgol_filter:
        magnitude = savgol_filter(magnitude, window_length=7, polyorder=2)

    # Smoothing can undershoot below zero near sharp gradients; clip it away.
    magnitude = np.clip(magnitude, 0, None)

    # Apply masking. Thresholds come from the finite values only, so NaNs in
    # the input cannot turn the threshold itself into NaN.
    if apply_mask:
        valid = magnitude[np.isfinite(magnitude)]
        if valid.size > 0:
            if use_percentage:
                # Cutoff that keeps the top mask_percentile % of the values.
                sorted_vals = np.sort(valid)
                cut_index = int((1 - mask_percentile / 100) * sorted_vals.size)
                # Clamp so mask_percentile <= 0 or > 100 cannot index outside
                # the array (at 0 only the maximum value survives).
                cut_index = min(max(cut_index, 0), sorted_vals.size - 1)
                cutoff_value = sorted_vals[cut_index]
            else:
                # Use standard percentile-based thresholding
                cutoff_value = np.percentile(valid, mask_percentile)
            # Mask all values below cutoff
            magnitude = np.where(magnitude >= cutoff_value, magnitude, np.nan)

    # save a copy of the segmented spectrogram
    segmented_stft = magnitude.copy()

    # Skip the plot if everything got masked, but still return the arrays so
    # callers that unpack the result keep working.
    if not np.any(np.isfinite(magnitude)):
        print(f"Shot {shot_id} — all values masked. Skipping plot.")
        return t, f, segmented_stft, Zxx

    # Plot the spectrogram
    fig, ax = plt.subplots(figsize=(15, 5))
    cax = ax.pcolormesh(
        t, f / 1000, magnitude,
        shading='nearest',
        cmap=plt.get_cmap(cmap, 15),
        norm=LogNorm(vmin=1e-5)
    )
    ax.set_ylim(0, fmax_kHz)
    #ax.set_xlim(tmin, tmax)
    ax.set_ylabel('Frequency [kHz]')
    ax.set_xlabel('Time [sec]')
    # `is not None` keeps a shot ID of 0 in the title.
    title = f"Filtered STFT - Shot {shot_id}" if shot_id is not None else "Filtered STFT"
    ax.set_title(title)
    plt.colorbar(cax, ax=ax, label='Amplitude')
    plt.tight_layout()

    return t, f, segmented_stft, Zxx
# Mask every shot with the same settings (Savitzky-Golay smoothing on,
# Gaussian blur off), then regroup the per-shot result tuples into one
# sequence per returned quantity.
results = [
    plot_amplitude_masking(
        trace,
        shot_ids[k],
        mask_percentile=1,
        sigma=0.9,
        apply_gaussian=False,
        apply_savgol_filter=True,
    )
    for k, trace in enumerate(ds_list)
]
t_list, f_list, seg_list, Zxx_list = zip(*results)
../../_images/d552c1fec15c27c72c9c72bbce7e341835fab8035204ddaff2a0b1810a2a2528.png ../../_images/64044cc7a1a134458695eba158a3da552f3e9dc9eabd5c1405705c91e0ddfab1.png ../../_images/95bedea897b8b18650da0e8aea3c3bf4b35307335cdcb56124575c8852135c7b.png ../../_images/8b6b2d901711f8b400db6e16b9d94537ff597275bdb9783587ad8d3a88cdf5b6.png

Contour detection

def plot_spectrogram_with_contours(Zxx, f, t, contours, shot_id=None, vmin=1e-5, fmax_kHz=70):
    """
    Plot a spectrogram and draw contour outlines on top of it.

    Parameters:
    - Zxx: 2-D array indexed [frequency, time]; its absolute value is
      plotted, so a complex STFT or a magnitude array both work.
    - f: Frequency axis in Hz (shown in kHz).
    - t: Time axis in seconds.
    - contours: Iterable of (N, 2) arrays of (row, column) coordinates, i.e.
      (frequency index, time index), such as those produced by
      skimage.measure.find_contours.
    - shot_id: Optional shot ID appended to the title.
    - vmin: Lower bound of the logarithmic colour scale.
    - fmax_kHz: Upper limit of the frequency axis in kHz (default 70, the
      value that was previously hard-coded).

    Returns:
    - None. The figure is created through pyplot as a side effect.
    """
    f_kHz = f / 1000

    fig, ax = plt.subplots(figsize=(12, 5))
    cax = ax.pcolormesh(t, f_kHz, np.abs(Zxx), shading='nearest',
                        norm=LogNorm(vmin=vmin), cmap='jet')

    for contour in contours:
        # Contour vertices are fractional indices; truncate to integers and
        # clip so they can safely index the axis arrays.
        time_idx = np.clip(contour[:, 1].astype(int), 0, len(t) - 1)
        freq_idx = np.clip(contour[:, 0].astype(int), 0, len(f) - 1)
        ax.plot(t[time_idx], f_kHz[freq_idx], color='black', lw=0.5)

    ax.set_ylim(0, fmax_kHz)
    #ax.set_xlim(0.1, 0.46)
    ax.set_xlabel('Time [sec]')
    ax.set_ylabel('Frequency [kHz]')
    title = "STFT with Contour Overlay"
    if shot_id is not None:
        title += f" - Shot {shot_id}"
    ax.set_title(title)
    plt.colorbar(cax, ax=ax, label="Amplitude")
    # Address the spectrogram axes explicitly instead of pyplot's current axes.
    ax.grid(True)
    plt.tight_layout()
# Outline the surviving (non-NaN) regions of each masked spectrogram.
for i, masked_stft in enumerate(seg_list):
    # Masked cells are NaN, so "finite" marks the cells that were kept.
    binary_mask = np.isfinite(masked_stft).astype(float)
    # 0.5 lies halfway between the mask's two values (0 and 1).
    contours = measure.find_contours(binary_mask, level=0.5)
    plot_spectrogram_with_contours(
        masked_stft,
        f_list[i],
        t_list[i],
        contours,
        shot_id=shot_ids[i],
    )
../../_images/2a5f4c44b925bd4840aeddc323744e51044e2019f9d70564b73737336ae3896a.png ../../_images/c52995d83b93df977f4565c8b62e50321c7c6f7c4982bcb4e31f5e76d5057ed7.png ../../_images/231bb6f58a8751988858a08d78ad6604d69dc8806150cfdcb40cd8bd32011060.png ../../_images/c7a0cb60ac9195a2777702ff3b27761305a7bd3f597dc1270f97e41a283c8e74.png

Extract information about each contour

def extract_contour_features(contour, t, f, Zxx):
    """
    Extract summary features from a single contour.

    Parameters:
    - contour: (N, 2) array of (row, column) coordinates, i.e. (frequency
      index, time index). Fractional values are truncated to integers and
      clipped to the valid index range.
    - t: 1-D array of times; indexed by the contour's column values.
    - f: 1-D array of frequencies; indexed by the contour's row values
      (kept in the units of `f`, no kHz conversion).
    - Zxx: 2-D array indexed [frequency, time]; its absolute value is sampled
      at the contour points.

    Returns:
    - dict of features, or None for a degenerate contour with fewer than two
      points. 'start_freq' / 'end_freq' are the mean frequency of the contour
      points at its earliest / latest time. 'slope' is the linear-fit slope
      of frequency against time, NaN when the contour has zero duration.
    """
    contour = np.asarray(contour)

    # Handle degenerate contours before any indexing or reductions.
    if contour.shape[0] < 2:
        return None

    time_idx = np.clip(contour[:, 1].astype(int), 0, len(t) - 1)
    freq_idx = np.clip(contour[:, 0].astype(int), 0, len(f) - 1)

    times = t[time_idx]
    freqs = f[freq_idx]# / 1000  # in kHz
    amps = np.abs(Zxx[freq_idx, time_idx])

    # Features
    start_time = times.min()
    end_time = times.max()
    duration = end_time - start_time
    freq_span = freqs.max() - freqs.min()

    # A contour confined to a single time column has no defined slope;
    # fitting it would be rank-deficient, so report NaN instead.
    if duration > 0:
        slope = np.polyfit(times, freqs, 1)[0]
    else:
        slope = np.nan

    # A closed contour starts and ends on the same vertex, so its first and
    # last points always share one frequency. Use the points at the earliest
    # and latest time so the two features describe onset and termination.
    start_freq = freqs[time_idx == time_idx.min()].mean()
    end_freq = freqs[time_idx == time_idx.max()].mean()

    return {
        'duration': duration,
        'freq_span': freq_span,
        'slope': slope,
        'avg_amp': np.mean(amps),
        'max_amp': np.max(amps),
        'start_time': start_time,
        'end_time': end_time,
        'start_freq': start_freq,
        'end_freq': end_freq,
        'length': len(times),
    }
# Flat list of per-contour feature dicts gathered across all shots
all_feature_dfs = []

for shot_pos, masked_stft in enumerate(seg_list):
    # Cells that survived masking are finite; masked cells are NaN.
    binary_mask = np.isfinite(masked_stft).astype(float)

    # Trace region outlines halfway between the mask's two values (0 and 1).
    contours = measure.find_contours(binary_mask, level=0.5)

    # Pair each contour's index with its features; the index is stored so a
    # feature row can later be matched back to its contour.
    candidates = (
        (
            contour_idx,
            extract_contour_features(
                contour, t_list[shot_pos], f_list[shot_pos], Zxx_list[shot_pos]
            ),
        )
        for contour_idx, contour in enumerate(contours)
    )

    # Drop degenerate contours (None) and tag the rest with their origin.
    features = [
        {**feat, 'shot_id': shot_ids[shot_pos], 'contour_idx': contour_idx}
        for contour_idx, feat in candidates
        if feat is not None
    ]

    all_feature_dfs.extend(features)

# One row per contour, one column per feature
df_all = pd.DataFrame(all_feature_dfs)
df_all.head()
duration freq_span slope avg_amp max_amp start_time end_time start_freq end_freq length shot_id contour_idx
0 0.004 0.000 0.000000 0.023028 0.024961 0.116000 0.120000 0.0 0.0 4 23447 0
1 0.008 0.000 0.000000 0.025997 0.033559 0.178000 0.186000 0.0 0.0 6 23447 1
2 0.004 0.000 0.000000 0.021432 0.024144 0.198000 0.202000 0.0 0.0 4 23447 2
3 0.008 0.000 0.000000 0.021969 0.025508 0.208000 0.216000 0.0 0.0 6 23447 3
4 0.048 8999.982 -85343.955436 0.012600 0.110189 0.404001 0.452001 0.0 0.0 148 23447 4
# Drop rows with missing feature values, then apply the heuristic quality
# cuts (minimum duration, amplitude, point count and frequency span, and a
# bound on the absolute slope) in one chained selection.
###### More filtering needed. A lot of garbage contour still left #######
df_clean = df_all.dropna().loc[
    lambda d: (d['duration'] > 0.004)
    & (d['avg_amp'] > 1e-4)
    & (d['length'] >= 5)
    & (d['freq_span'] > 0.1)
    & (d['slope'].abs() < 5e4)
]

# Preview the surviving contours of shot 23447
df_clean.loc[df_clean['shot_id'] == 23447].head()
duration freq_span slope avg_amp max_amp start_time end_time start_freq end_freq length shot_id contour_idx
7 0.064 5749.9885 30710.906079 0.023135 0.156930 0.352001 0.416001 20499.9590 20499.9590 145 23447 7
8 0.004 499.9990 -23026.223684 0.030197 0.046078 0.434001 0.438001 15499.9690 15499.9690 9 23447 8
11 0.010 2249.9955 -5161.559246 0.013658 0.051164 0.332001 0.342001 25249.9495 25249.9495 29 23447 11
19 0.050 4749.9905 43065.864199 0.021728 0.149185 0.364001 0.414001 39499.9210 39499.9210 121 23447 19
21 0.004 749.9985 4166.650000 0.031503 0.057582 0.334001 0.338001 38249.9235 38249.9235 11 23447 21
# Number of contours remaining after the filters above
len(df_clean)
261
# Summary statistics of every column of the filtered contours
df_clean.describe()
duration freq_span slope avg_amp max_amp start_time end_time start_freq end_freq length shot_id contour_idx
count 261.000000 261.000000 261.000000 261.000000 261.000000 261.000000 261.000000 261.000000 261.000000 261.000000 261.000000 261.00000
mean 0.006690 1160.917218 -864.429527 0.022488 0.049807 0.345472 0.352162 57955.822785 57955.822785 17.896552 29098.823755 78.94636
std 0.005537 1244.587927 19341.970055 0.006750 0.015256 0.091565 0.091794 34317.698029 34317.698029 17.018113 2414.595628 52.03258
min 0.004000 249.999500 -48348.863208 0.005850 0.024897 0.052000 0.068000 0.000000 0.000000 7.000000 23447.000000 0.00000
25% 0.004000 499.999000 -10445.163699 0.018000 0.039058 0.294001 0.302001 25999.948000 25999.948000 9.000000 30005.000000 38.00000
50% 0.006000 749.998500 -3348.200893 0.022258 0.049379 0.320001 0.324001 63749.872500 63749.872500 13.000000 30021.000000 68.00000
75% 0.008000 1249.997500 8928.535714 0.026734 0.057717 0.380001 0.384001 87999.824000 87999.824000 19.000000 30021.000000 108.00000
max 0.064000 7749.984500 48275.668966 0.038196 0.156930 0.660001 0.668001 138749.722500 138749.722500 145.000000 30421.000000 209.00000

PCA

# Feature columns fed to PCA: shape and amplitude descriptors plus the
# temporal and frequency position of each contour.
features_for_pca = df_clean.loc[:, [
    'duration',
    'freq_span',
    'slope',
    'avg_amp',
    'max_amp',
    'length',
    'start_time',
    'end_time',
    'start_freq',
    'end_freq',
]]

# Standardise each feature so none dominates through its units, then project
# onto the first two principal components.
X_scaled = StandardScaler().fit_transform(features_for_pca)
pca = PCA(n_components=2)
X_pca = pca.fit_transform(X_scaled)

# Scatter of the contours in the two-component space
plt.figure(figsize=(10, 6))
plt.scatter(*X_pca.T, alpha=0.5)
plt.title("Contour Feature Space (PCA)")
plt.xlabel("PCA 1")
plt.ylabel("PCA 2")
plt.grid(True)
../../_images/1cfca5d15571ce9ff511f4b9c9b85b26c1cadb166bd4333e42c63e8c42dfb67d.png
pca.components_  # Rows are PCs, columns are original features
array([[-0.30448631, -0.33932833, -0.04914475,  0.2417678 , -0.2029942 ,
        -0.35094409, -0.34756518, -0.36506315,  0.39549744,  0.39549744],
       [ 0.37581452,  0.34228755,  0.11716779, -0.03324786,  0.36578401,
         0.43375828, -0.36790736, -0.34431705,  0.27469504,  0.27469504]])

Simple clustering

from sklearn.cluster import KMeans

# Group the contours into a fixed number of clusters in the 2-D PCA space;
# the fixed seed makes the assignment repeatable.
n_clusters = 8
kmeans = KMeans(n_clusters=n_clusters, random_state=42)
cluster_labels = kmeans.fit_predict(X_pca)

# One scatter series per cluster so each gets its own colour and legend entry
plt.figure(figsize=(8, 6))
for cluster_id in range(n_clusters):
    members = X_pca[cluster_labels == cluster_id]
    plt.scatter(members[:, 0], members[:, 1], label=f"Cluster {cluster_id}", alpha=0.7)

plt.title("Contour Clustering via PCA + KMeans")
plt.xlabel("PCA 1")
plt.ylabel("PCA 2")
plt.grid(True)
plt.legend()
plt.tight_layout()
../../_images/e0477545da2c6ad698ea21b51b64051a50573a6c08df61005f99d914a269a70a.png
# Attach each contour's K-means label. assign() builds a new frame instead of
# writing a column into the filtered selection of df_all.
df_clean = df_clean.assign(cluster=cluster_labels)
num_clusters = df_clean['cluster'].nunique()
# pyplot.get_cmap replaces matplotlib.cm.get_cmap, which is deprecated since
# Matplotlib 3.7 and scheduled for removal.
colors = plt.get_cmap('tab10', num_clusters)

for shot_id in df_clean['shot_id'].unique():
    # Pull the STFT magnitude and axes for this shot
    idx = shot_ids.index(shot_id)
    Z = np.abs(Zxx_list[idx])
    t = t_list[idx]
    f = f_list[idx]

    # Rows (contours) of this shot, and the contours recomputed from the same
    # masked spectrogram so that 'contour_idx' refers to the same ordering.
    df_shot = df_clean[df_clean['shot_id'] == shot_id]
    contours = measure.find_contours(np.isfinite(seg_list[idx]).astype(float), 0.5)

    plt.figure(figsize=(12, 5))
    plt.imshow(
        Z,
        aspect='auto', origin='lower',
        extent=[t[0], t[-1], f[0], f[-1]],
        cmap='inferno', norm=LogNorm(vmin=1e-5, vmax=1e-1)
    )

    # Draw each retained contour in the colour of its cluster
    for _, row in df_shot.iterrows():
        contour = contours[int(row['contour_idx'])]     # this links to clustering
        t_pts = t[contour[:, 1].astype(int)]
        f_pts = f[contour[:, 0].astype(int)]
        plt.plot(t_pts, f_pts, color=colors(int(row['cluster'])), linewidth=2)

    plt.colorbar(label='Amplitude')
    plt.title(f"Clustered Contours - Shot {shot_id}")
    plt.xlabel("Time [s]")
    plt.ylabel("Frequency [Hz]")
    plt.ylim(0, 50000)
    plt.tight_layout()
    plt.show()
/tmp/ipykernel_13399/3660331428.py:2: MatplotlibDeprecationWarning: The get_cmap function was deprecated in Matplotlib 3.7 and will be removed in 3.11. Use ``matplotlib.colormaps[name]`` or ``matplotlib.colormaps.get_cmap()`` or ``pyplot.get_cmap()`` instead.
  colors = cm.get_cmap('tab10', num_clusters)
../../_images/c2ed9841ebc5cc5686c1cb453df522617752c5c03c9512e0ab94b90b2f7c8bd9.png ../../_images/08848fd3e87e8c327447cc21dc41c9f78bdab76d17a13435726d0358107cae40.png ../../_images/7f3d7cc2ef64dd2b2a29789bfd99cf4ec195fd4062452d63e7d0c0aa2bd14cc9.png ../../_images/e7e58830067238650e2dd36cc07d5472f818c87009fc9a4f1b40883a2473bad5.png