XAI for Sea Ice Classification of Waveform Characteristics

This notebook demonstrates techniques for classifying sea ice and leads from waveform data using both deep learning and tree-based methods, and shows how to identify which waveform characteristics contribute most to the classification.

Models Implemented

  • CNN: Captures patterns directly from raw waveforms

  • Random Forest: Robust ensemble approach with inherent feature importance

  • Gradient Boosting: Sequential learning focused on misclassification correction

  • XGBoost: Optimised gradient boosting with regularisation

Key Analysis Techniques

  • Preprocessing of waveform characteristics

  • SHAP values for CNN interpretability

  • Model comparison with confusion matrices

  • Identification of consensus important regions across models

The notebook identifies which waveform characteristics are most discriminative for sea ice classification, offering insights that may improve waveform classification.

1. Data Preprocessing

  • Reads waveform data from a CSV file and processes it into structured arrays.

  • Handles missing values and class label inconsistencies.

  • Splits data into training and test sets while scaling features for consistency.
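The split-and-scale step can be illustrated with a minimal NumPy sketch. This is not the notebook's code (which uses scikit-learn's `train_test_split` and `StandardScaler`); the helper `stratified_split` and the toy data are hypothetical. It shows two ideas: each class keeps roughly the same share in both sets, and the scaling statistics come from the training set only.

```python
import numpy as np

def stratified_split(y, test_size=0.2, seed=42):
    """Return (train_idx, test_idx) so each class keeps roughly the same share."""
    rng = np.random.default_rng(seed)
    train_idx, test_idx = [], []
    for label in np.unique(y):
        idx = rng.permutation(np.flatnonzero(y == label))
        n_test = int(round(test_size * len(idx)))
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return np.array(train_idx), np.array(test_idx)

# Hypothetical toy data: 100 samples, 3 features, 10% minority class
rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=2.0, size=(100, 3))
y = np.array([0] * 10 + [1] * 90)

train_idx, test_idx = stratified_split(y)
X_train, X_test = X[train_idx], X[test_idx]

# Fit the scaling statistics on the training set only, then reuse them for the test set
mean, std = X_train.mean(axis=0), X_train.std(axis=0)
X_train_scaled = (X_train - mean) / std
X_test_scaled = (X_test - mean) / std

print(np.bincount(y[train_idx]), np.bincount(y[test_idx]))
```

Reusing the training mean and standard deviation on the test set keeps information about the test distribution out of the fitted pipeline.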


2. CNN Model for Waveform Classification

  • Implements a Convolutional Neural Network (CNN) to classify sea ice and leads from radar waveform characteristics.

  • Uses multiple convolutional layers with batch normalisation, max pooling, dropout, and dense layers.

  • Compiles with the Adam optimizer and binary cross-entropy loss, monitoring performance with accuracy and AUC metrics.

  • Trains with early stopping and learning rate reduction strategies to prevent overfitting.

  • Plots training history to assess model performance over epochs.
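To make the idea of a convolutional layer concrete, the sketch below runs a single 1D convolution, a ReLU, and a max-pooling step by hand in NumPy. It is an illustration only: the toy waveform and the hand-picked `edge_kernel` are assumptions, whereas a trained network learns its own kernels.

```python
import numpy as np

def conv1d_valid(signal, kernel):
    """Cross-correlation without padding, the operation a 1D convolutional layer applies."""
    return np.correlate(signal, kernel, mode="valid")

def relu(x):
    return np.maximum(x, 0.0)

def max_pool1d(x, pool_size=2):
    """Non-overlapping max pooling; trailing bins that do not fill a window are dropped."""
    n = (len(x) // pool_size) * pool_size
    return x[:n].reshape(-1, pool_size).max(axis=1)

# Hypothetical toy waveform: flat noise floor, sharp rise at bin 6, slow decay
waveform = np.array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 1.0, 0.8, 0.6, 0.4, 0.3, 0.2])

# Hand-picked kernel that responds to a rising edge (an assumption for illustration)
edge_kernel = np.array([-1.0, 1.0])

feature_map = relu(conv1d_valid(waveform, edge_kernel))
pooled = max_pool1d(feature_map, pool_size=2)

print("strongest response at bin", int(np.argmax(feature_map)))
print("pooled feature map:", pooled)
```

The feature map is non-zero only where the waveform rises, which is the sense in which a convolutional filter picks out a local pattern such as a leading edge regardless of where it occurs.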


3. Model Evaluation & Explainability (SHAP & Gradients)

  • Evaluates the trained CNN model on the test set, reporting accuracy, confusion matrix, and classification metrics.

  • Uses SHAP (SHapley Additive exPlanations) values to interpret feature importance, identifying key waveform regions influencing classification.

  • Helps in understanding which waveform characteristics contribute most to sea ice classification decisions.
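SHAP values are based on Shapley values from cooperative game theory. As a rough sketch of the quantity the `shap` library approximates, the code below computes exact Shapley values by enumerating every feature coalition. It does not use the `shap` API; the function `toy_model`, its coefficients, and the all-zero baseline are hypothetical.

```python
from itertools import combinations
from math import factorial

def shapley_values(predict, x, baseline):
    """Exact Shapley values by enumerating every feature coalition.

    Features outside the coalition are replaced by their baseline value.
    Cost grows as 2**n_features, so this is only practical for small inputs.
    """
    n = len(x)

    def value(coalition):
        mixed = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return predict(mixed)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi

# Hypothetical scoring function with an interaction between the first two features
def toy_model(features):
    peakiness, width, noise = features
    return 2.0 * peakiness - 1.0 * width + 0.5 * peakiness * width + 0.1 * noise

x = [3.0, 2.0, 1.0]
baseline = [0.0, 0.0, 0.0]
phi = shapley_values(toy_model, x, baseline)

print("Shapley values:", phi)
print("sum:", sum(phi), "prediction gap:", toy_model(x) - toy_model(baseline))
```

The attributions sum to the difference between the prediction for `x` and the prediction for the baseline, and the interaction term is shared equally between the two features involved. This additivity is what makes per-feature contributions comparable across samples.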

We need to install GPy and restart the session, as we did in the last notebook.

!pip install Gpy
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, roc_curve, auc
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, BatchNormalization
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
import seaborn as sns
import ast
from scipy.optimize import curve_fit
from scipy.stats import median_abs_deviation
import shap

# Function to parse string representation of waveform array into numeric values
def parse_waveform(waveform_str):
    """Parse the string representation of waveform array into numeric values"""
    cleaned_str = waveform_str.replace('e+', 'e').replace(' ', '')

    try:
        # Use ast.literal_eval for safe evaluation of the string as a list
        waveform_list = ast.literal_eval(cleaned_str)
        return np.array(waveform_list)
    except (SyntaxError, ValueError) as e:
        try:
            # Manual parsing fallback
            values_str = waveform_str.strip('[]').split(',')
            values = []
            for val in values_str:
                val = val.strip()
                if val:  # Skip empty strings
                    values.append(float(val))
            return np.array(values)
        except Exception as e2:
            print(f"Error parsing waveform: {e2}")
            return np.array([])

# ----- Waveform Feature Extraction Functions -----

def extract_waveform_features(waveform):
    """Extract features from a radar waveform as described in the literature"""
    features = {}

    # Normalize waveform for relative measurements
    if np.max(waveform) > 0:
        norm_waveform = waveform / np.max(waveform)
    else:
        norm_waveform = waveform.copy()

    # 1. Leading Edge Width (LEW)
    # Find the positions where the waveform crosses 30% and 70% thresholds on rising edge
    try:
        # Find the maximum position
        max_idx = np.argmax(norm_waveform)

        # Find indexes before the peak where waveform crosses 30% and 70% thresholds
        thresh_30_idx = None
        thresh_70_idx = None

        for i in range(max_idx):
            if norm_waveform[i] <= 0.3 and norm_waveform[i+1] > 0.3:
                # Interpolate to find more precise bin position
                frac = (0.3 - norm_waveform[i]) / (norm_waveform[i+1] - norm_waveform[i])
                thresh_30_idx = i + frac
            if norm_waveform[i] <= 0.7 and norm_waveform[i+1] > 0.7:
                # Interpolate to find more precise bin position
                frac = (0.7 - norm_waveform[i]) / (norm_waveform[i+1] - norm_waveform[i])
                thresh_70_idx = i + frac

        # If either threshold crossing was not found before the peak, mark LEW as missing
        if thresh_30_idx is None or thresh_70_idx is None:
            features['lew'] = np.nan
        else:
            features['lew'] = thresh_70_idx - thresh_30_idx
    except Exception:
        features['lew'] = np.nan

    # 2. Waveform Maximum (Wm)
    features['wm'] = np.max(waveform)

    # 3. Trailing Edge Decline (Ted)
    try:
        # Exponential decay function for fitting
        def exp_decay(x, a, b):
            return a * np.exp(-b * x)

        max_idx = np.argmax(waveform)
        trailing_edge = waveform[max_idx:]

        if len(trailing_edge) > 3:  # Need at least a few points for fitting
            x_data = np.arange(len(trailing_edge))

            # Avoid curve_fit errors by ensuring positive values
            trailing_edge_pos = np.maximum(trailing_edge, 1e-10)

            # Initial guess for parameters
            p0 = [trailing_edge_pos[0], 0.1]

            try:
                # Fit exponential decay to trailing edge
                popt, _ = curve_fit(exp_decay, x_data, trailing_edge_pos, p0=p0, maxfev=1000)
                features['ted'] = popt[1]  # Decay rate parameter
            except Exception:  # curve_fit raises RuntimeError when it fails to converge
                features['ted'] = np.nan
        else:
            features['ted'] = np.nan
    except Exception:
        features['ted'] = np.nan

    # 4. Waveform Noise (Wn) - MAD of trailing edge residuals
    try:
        max_idx = np.argmax(waveform)
        trailing_edge = waveform[max_idx:]

        if len(trailing_edge) > 3 and 'ted' in features and not np.isnan(features['ted']):
            # Rebuild the decay curve from the fitted rate (Ted); the amplitude used here
            # is the waveform maximum, not the amplitude fitted by curve_fit
            x_data = np.arange(len(trailing_edge))
            fitted_values = features['wm'] * np.exp(-features['ted'] * x_data)

            # Calculate residuals
            residuals = trailing_edge - fitted_values

            # MAD of residuals
            features['wn'] = median_abs_deviation(residuals, scale=1.0)
        else:
            features['wn'] = np.nan
    except Exception:
        features['wn'] = np.nan

    # 5. Waveform Width (Ww)
    # Count bins with power > 0
    features['ww'] = np.sum(waveform > 0)

    # 6. Leading Edge Slope (Les)
    try:
        max_idx = np.argmax(norm_waveform)

        # Find bins where waveform exceeds 30% of max
        thresh_30_idx = None
        for i in range(max_idx):
            if norm_waveform[i] <= 0.3 and norm_waveform[i+1] > 0.3:
                thresh_30_idx = i
                break

        if thresh_30_idx is not None:
            # Calculate difference between max bin and first 30% threshold bin
            features['les'] = max_idx - thresh_30_idx
        else:
            features['les'] = np.nan
    except Exception:
        features['les'] = np.nan

    # 7. Trailing Edge Slope (Tes)
    try:
        max_idx = np.argmax(norm_waveform)

        # Find last bin where waveform exceeds 30% of max
        thresh_30_idx = None
        for i in range(len(norm_waveform)-1, max_idx, -1):
            if norm_waveform[i] <= 0.3 and norm_waveform[i-1] > 0.3:
                thresh_30_idx = i
                break

        if thresh_30_idx is not None:
            # Calculate difference between last 30% threshold bin and max bin
            features['tes'] = thresh_30_idx - max_idx
        else:
            features['tes'] = np.nan
    except Exception:
        features['tes'] = np.nan

    # 8. Pulse Peakiness (PP) - common waveform feature in literature
    try:
        features['pp'] = features['wm'] / np.mean(waveform)
    except Exception:
        features['pp'] = np.nan

    # 9. Max position (relative bin where the maximum occurs)
    try:
        features['max_pos'] = np.argmax(waveform) / len(waveform)
    except Exception:
        features['max_pos'] = np.nan

    return features

def extract_waveforms_and_features(roughness, target_column):
    """Process all waveforms and extract features"""
    # Lists to store data
    feature_list = []
    target_list = []
    valid_indices = []
    raw_waveforms = []

    print("Extracting features from waveforms...")

    # Process each row
    for idx, row in roughness.iterrows():
        if idx % 500 == 0:
            print(f"Processing row {idx}/{len(roughness)}...")

        try:
            # Get the waveform and class
            waveform_str = str(row['Matched_Waveform_20_Ku'])
            waveform_array = parse_waveform(waveform_str)

            # Get class label
            class_label = int(row[target_column])

            # Extract features if valid waveform
            if len(waveform_array) > 0:
                # Extract features
                features = extract_waveform_features(waveform_array)

                # Skip if any feature is NaN
                if np.any(np.isnan(list(features.values()))):
                    continue

                # Store data
                feature_list.append(list(features.values()))
                target_list.append(class_label)
                valid_indices.append(idx)
                raw_waveforms.append(waveform_array)

        except Exception as e:
            if idx < 5:  # Print first few errors
                print(f"Error processing row {idx}: {e}")
            continue

    # Convert to numpy arrays
    X_features = np.array(feature_list)
    y = np.array(target_list)
    X_raw = np.array(raw_waveforms)

    # Create feature names
    feature_names = ['Leading Edge Width', 'Waveform Maximum',
                     'Trailing Edge Decline', 'Waveform Noise',
                     'Waveform Width', 'Leading Edge Slope',
                     'Trailing Edge Slope', 'Pulse Peakiness',
                     'Max Position']

    return X_features, y, X_raw, valid_indices, feature_names

# ----- Main Processing -----
roughness = pd.read_csv('/content/drive/MyDrive/GEOL0069/2324/Week 9 2025/updated_filtered_matched_uit_sentinel3_L2_alongtrack_2023_04_official.txt')
print(roughness)
print("Starting waveform feature extraction and classification...")

# Define the target column
target_column = 'Sea_Ice_Class'
if target_column not in roughness.columns:
    print(f"Warning: {target_column} not found. Available columns: {roughness.columns}")
    print("Will use 'Lead_Class' as target instead.")
    target_column = 'Lead_Class'

# Extract features from all waveforms
X_features, y, X_raw, valid_indices, feature_names = extract_waveforms_and_features(roughness, target_column)

print(f"Extracted {len(feature_names)} features from {len(X_features)} valid waveforms")
print(f"Features: {feature_names}")
print(f"Feature data shape: {X_features.shape}")
print(f"Class distribution: {np.bincount(y)}")

# Check for NaN values
nan_mask = np.isnan(X_features).any(axis=1)
X_features_clean = X_features[~nan_mask]
y_clean = y[~nan_mask]
X_raw_clean = X_raw[~nan_mask]

print(f"Removed {np.sum(nan_mask)} samples with NaN values")
print(f"Clean feature data shape: {X_features_clean.shape}")

# Split data into training and test sets
X_train, X_test, y_train, y_test, X_raw_train, X_raw_test = train_test_split(
    X_features_clean, y_clean, X_raw_clean,
    test_size=0.2, random_state=42, stratify=y_clean
)

# Scale features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
       3A=0_3B=1  Orbit_#  Segment_#     Datenumber   Latitude  Longitude  \
0              0        1          1  738977.028790  74.432732 -73.103512   
1              0        1          1  738977.028790  74.435182 -73.109831   
2              0        1          1  738977.028791  74.437633 -73.116151   
3              0        1          1  738977.028791  74.440083 -73.122474   
4              0        1          1  738977.028792  74.442533 -73.128797   
...          ...      ...        ...            ...        ...        ...   
28720          0      843        843  739006.699439  71.763165 -74.486482   
28721          0      843        843  739006.699440  71.760574 -74.491177   
28722          0      843        843  739006.699440  71.757982 -74.495872   
28723          0      843        843  739006.699450  71.711313 -74.580159   
28724          0      843        843  739006.699451  71.708719 -74.584830   

       Radar_Freeboard  Surface_Height_WGS84  Sea_Surface_Height_Interp_WGS84  \
0            -0.044136             14.319994                        14.364129   
1            -0.099506             14.265998                        14.365504   
2             0.029215             14.396251                        14.367035   
3             0.062597             14.431502                        14.368905   
4            -0.044797             14.326136                        14.370932   
...                ...                   ...                              ...   
28720         0.362561              0.308307                        -0.054254   
28721         0.276995              0.205567                        -0.071428   
28722         0.301225              0.211470                        -0.089755   
28723         0.165803             -0.241391                        -0.407193   
28724         0.255920             -0.165441                        -0.421361   

       SSH_Uncertainty  ...  Lead_Class  Sea_Ice_Roughness  \
0             0.000254  ...           0           0.017885   
1             0.000232  ...           0           0.003315   
2             0.000212  ...           0           0.018211   
3             0.000192  ...           0           0.051418   
4             0.000173  ...           0           0.046983   
...                ...  ...         ...                ...   
28720         0.014538  ...           0           0.067003   
28721         0.014706  ...           0           0.040842   
28722         0.014876  ...           0           0.034679   
28723         0.018095  ...           0           0.443806   
28724         0.018284  ...           0           0.058687   

       Sea_Ice_Concentration  Seconds_since_2000  Year  Month  Day  \
0                     1.0000        7.336249e+08  2023      4    1   
1                     1.0000        7.336249e+08  2023      4    1   
2                     1.0000        7.336249e+08  2023      4    1   
3                     1.0000        7.336249e+08  2023      4    1   
4                     1.0000        7.336249e+08  2023      4    1   
...                      ...                 ...   ...    ...  ...   
28720                 0.9083        7.361884e+08  2023      4   30   
28721                 0.9083        7.361884e+08  2023      4   30   
28722                 0.9083        7.361884e+08  2023      4   30   
28723                 0.9083        7.361884e+08  2023      4   30   
28724                 0.9083        7.361884e+08  2023      4   30   

              Proj_X        Proj_Y  \
0      646585.001913  8.269660e+06   
1      646373.758053  8.269917e+06   
2      646162.531358  8.270174e+06   
3      645951.280450  8.270431e+06   
4      645740.085773  8.270688e+06   
...              ...           ...   
28720  622709.601630  7.969277e+06   
28721  622562.605479  7.968979e+06   
28722  622415.569925  7.968680e+06   
28723  619768.967681  7.963311e+06   
28724  619621.913813  7.963013e+06   

                                  Matched_Waveform_20_Ku  
0      [2.524e+00,2.780e+00,2.194e+00,2.286e+00,2.054...  
1      [2.381e+00,2.182e+00,2.332e+00,2.352e+00,1.857...  
2      [2.173e+00,1.961e+00,2.388e+00,2.075e+00,2.609...  
3      [2.202e+00,2.223e+00,2.049e+00,2.213e+00,2.138...  
4      [3.252e+00,2.874e+00,2.988e+00,2.789e+00,2.546...  
...                                                  ...  
28720                                                NaN  
28721                                                NaN  
28722                                                NaN  
28723                                                NaN  
28724                                                NaN  

[28725 rows x 23 columns]
Starting waveform feature extraction and classification...
Extracting features from waveforms...
Processing row 0/28725...
Processing row 500/28725...
Processing row 1000/28725...
Processing row 1500/28725...
Processing row 2000/28725...
Processing row 2500/28725...
Processing row 3000/28725...
Processing row 3500/28725...
Processing row 4000/28725...
Processing row 4500/28725...
Processing row 5000/28725...
Processing row 5500/28725...
Processing row 6000/28725...
Processing row 6500/28725...
Processing row 7000/28725...
Processing row 7500/28725...
Processing row 8000/28725...
Processing row 8500/28725...
Processing row 9000/28725...
Processing row 9500/28725...
Processing row 10000/28725...
Processing row 10500/28725...
Processing row 11000/28725...
Processing row 11500/28725...
Processing row 12000/28725...
Processing row 12500/28725...
Processing row 13000/28725...
Processing row 13500/28725...
Processing row 14000/28725...
Processing row 14500/28725...
Processing row 15000/28725...
Processing row 15500/28725...
Processing row 16000/28725...
Processing row 16500/28725...
Processing row 17000/28725...
Processing row 17500/28725...
Processing row 18000/28725...
Processing row 18500/28725...
Processing row 19000/28725...
Processing row 19500/28725...
Processing row 20000/28725...
Processing row 20500/28725...
Processing row 21000/28725...
Processing row 21500/28725...
Processing row 22000/28725...
Processing row 22500/28725...
Processing row 23000/28725...
Processing row 23500/28725...
Processing row 24000/28725...
Processing row 24500/28725...
Processing row 25000/28725...
Processing row 25500/28725...
Processing row 26000/28725...
Processing row 26500/28725...
Processing row 27000/28725...
Processing row 27500/28725...
Processing row 28000/28725...
Processing row 28500/28725...
Extracted 9 features from 12765 valid waveforms
Features: ['Leading Edge Width', 'Waveform Maximum', 'Trailing Edge Decline', 'Waveform Noise', 'Waveform Width', 'Leading Edge Slope', 'Trailing Edge Slope', 'Pulse Peakiness', 'Max Position']
Feature data shape: (12765, 9)
Class distribution: [  604 12161]
Removed 0 samples with NaN values
Clean feature data shape: (12765, 9)
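To give a feel for two of the nine features, the sketch below recomputes pulse peakiness and leading edge width on two synthetic waveforms. The definitions mirror the ones used in `extract_waveform_features` (maximum divided by mean, and the interpolated distance between the 30% and 70% crossings on the rising edge), but the compact helpers and both waveform shapes are assumptions made for illustration, not data from the file.

```python
import numpy as np

def pulse_peakiness(waveform):
    """Maximum power divided by mean power."""
    return float(np.max(waveform) / np.mean(waveform))

def leading_edge_width(waveform, low=0.3, high=0.7):
    """Bins between the low and high threshold crossings on the rising edge."""
    norm = waveform / np.max(waveform)
    peak = int(np.argmax(norm))

    def crossing(level):
        position = None
        for i in range(peak):
            if norm[i] <= level < norm[i + 1]:
                # Linear interpolation between the two bins around the crossing
                position = i + (level - norm[i]) / (norm[i + 1] - norm[i])
        return position

    start, end = crossing(low), crossing(high)
    if start is None or end is None:
        return float("nan")
    return end - start

bins = np.arange(128)

# Synthetic specular return: one narrow peak (assumed lead-like shape)
lead = np.exp(-0.5 * ((bins - 60) / 1.5) ** 2)

# Synthetic diffuse return: slower rise and long decaying tail (assumed ice-like shape)
ice = np.where(bins < 60,
               np.exp(-0.5 * ((bins - 60) / 6.0) ** 2),
               np.exp(-(bins - 60) / 20.0))

for name, waveform in [("lead-like", lead), ("ice-like", ice)]:
    print(name, "PP =", round(pulse_peakiness(waveform), 2),
          "LEW =", round(leading_edge_width(waveform), 2))
```

The narrow peak gives a much higher pulse peakiness and a shorter leading edge than the broad waveform, which is the kind of contrast these features are designed to capture.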


# ----- Neural Network Model for Feature-Based Classification -----

# Build a simple neural network model
model = Sequential([
    Dense(16, activation='relu', input_shape=(X_train_scaled.shape[1],)),
    BatchNormalization(),
    Dropout(0.3),
    Dense(8, activation='relu'),
    Dropout(0.2),
    Dense(1, activation='sigmoid')
])

# Compile the model
model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.AUC()]
)

# Model summary
model.summary()

# Callbacks
callbacks = [
    EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=0.0001)
]

# Apply class weights if imbalanced
class_weights = None
if len(np.unique(y_train)) > 1:
    n_samples = len(y_train)
    n_classes = len(np.unique(y_train))
    class_counts = np.bincount(y_train)
    if np.min(class_counts) / np.max(class_counts) < 0.5:  # If imbalanced
        print("Detected class imbalance, applying class weights")
        class_weights = {i: n_samples / (n_classes * count) for i, count in enumerate(class_counts)}

# Train the model
print("\nTraining neural network model...")
history = model.fit(
    X_train_scaled, y_train,
    epochs=50,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks,
    class_weight=class_weights,
    verbose=1
)

# Plot training history
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# ----- Model Evaluation -----


y_pred_prob = model.predict(X_test_scaled).flatten()
y_pred = (y_pred_prob > 0.5).astype(int)

accuracy = accuracy_score(y_test, y_pred)
print(f"\nAccuracy: {accuracy:.4f}")

print("\nClassification Report:")
print(classification_report(y_test, y_pred))

conf_matrix = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Lead (0)', 'Sea Ice (1)'],
            yticklabels=['Lead (0)', 'Sea Ice (1)'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix')
plt.show()
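The class-weight formula used before training, `n_samples / (n_classes * count)`, can be checked by hand. The sketch below applies it to the class distribution printed earlier (`[604, 12161]`). This is an approximation of what the notebook does, since the notebook computes the weights on `y_train` after the stratified 80/20 split, so the counts are smaller but the ratio is about the same. The helper name `balanced_class_weights` is hypothetical.

```python
def balanced_class_weights(class_counts):
    """Weight each class inversely to its frequency: n_samples / (n_classes * count)."""
    n_samples = sum(class_counts)
    n_classes = len(class_counts)
    return {label: n_samples / (n_classes * count)
            for label, count in enumerate(class_counts)}

counts = [604, 12161]  # class distribution printed before the train/test split
weights = balanced_class_weights(counts)
imbalance_ratio = min(counts) / max(counts)

print("imbalance ratio:", round(imbalance_ratio, 3))
print("class weights:", {k: round(v, 3) for k, v in weights.items()})
print("weighted totals:", [weights[k] * counts[k] for k in weights])
```

With these counts the minority class gets a weight of roughly 10.6 and the majority class roughly 0.52, so both classes contribute the same total weight to the loss. The imbalance ratio is about 0.05, well under the 0.5 threshold that switches the weighting on. This is the same heuristic scikit-learn applies for `class_weight='balanced'`.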


/usr/local/lib/python3.11/dist-packages/keras/src/layers/core/dense.py:87: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
  super().__init__(activity_regularizer=activity_regularizer, **kwargs)
Model: "sequential_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓
┃ Layer (type)                         ┃ Output Shape                ┃         Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩
│ dense_4 (Dense)                      │ (None, 16)                  │             160 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ batch_normalization_1                │ (None, 16)                  │              64 │
│ (BatchNormalization)                 │                             │                 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dropout_2 (Dropout)                  │ (None, 16)                  │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense_5 (Dense)                      │ (None, 8)                   │             136 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dropout_3 (Dropout)                  │ (None, 8)                   │               0 │
├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤
│ dense_6 (Dense)                      │ (None, 1)                   │               9 │
└──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
 Total params: 369 (1.44 KB)
 Trainable params: 337 (1.32 KB)
 Non-trainable params: 32 (128.00 B)
Detected class imbalance, applying class weights

Training neural network model...
Epoch 1/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 7s 14ms/step - accuracy: 0.5987 - auc_1: 0.5160 - loss: 1.0171 - val_accuracy: 0.5859 - val_auc_1: 0.9090 - val_loss: 0.5763 - learning_rate: 0.0010
Epoch 2/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.6790 - auc_1: 0.8197 - loss: 0.5255 - val_accuracy: 0.6721 - val_auc_1: 0.9266 - val_loss: 0.4585 - learning_rate: 0.0010
Epoch 3/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7154 - auc_1: 0.8620 - loss: 0.4470 - val_accuracy: 0.7367 - val_auc_1: 0.9282 - val_loss: 0.3785 - learning_rate: 0.0010
Epoch 4/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.7772 - auc_1: 0.8951 - loss: 0.4121 - val_accuracy: 0.7288 - val_auc_1: 0.9296 - val_loss: 0.3542 - learning_rate: 0.0010
Epoch 5/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7839 - auc_1: 0.9084 - loss: 0.3748 - val_accuracy: 0.7239 - val_auc_1: 0.9302 - val_loss: 0.3491 - learning_rate: 0.0010
Epoch 6/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.7602 - auc_1: 0.9008 - loss: 0.3841 - val_accuracy: 0.7024 - val_auc_1: 0.9261 - val_loss: 0.3546 - learning_rate: 0.0010
Epoch 7/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7219 - auc_1: 0.9073 - loss: 0.3571 - val_accuracy: 0.6960 - val_auc_1: 0.9287 - val_loss: 0.3392 - learning_rate: 0.0010
Epoch 8/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7216 - auc_1: 0.9023 - loss: 0.3778 - val_accuracy: 0.6985 - val_auc_1: 0.9304 - val_loss: 0.3253 - learning_rate: 0.0010
Epoch 9/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7148 - auc_1: 0.9115 - loss: 0.3437 - val_accuracy: 0.6931 - val_auc_1: 0.9301 - val_loss: 0.3325 - learning_rate: 0.0010
Epoch 10/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.7191 - auc_1: 0.9179 - loss: 0.3579 - val_accuracy: 0.7000 - val_auc_1: 0.9325 - val_loss: 0.3224 - learning_rate: 0.0010
Epoch 11/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7230 - auc_1: 0.9275 - loss: 0.3244 - val_accuracy: 0.7073 - val_auc_1: 0.9335 - val_loss: 0.3016 - learning_rate: 0.0010
Epoch 12/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.7053 - auc_1: 0.9204 - loss: 0.3358 - val_accuracy: 0.6902 - val_auc_1: 0.9337 - val_loss: 0.3186 - learning_rate: 0.0010
Epoch 13/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7208 - auc_1: 0.9276 - loss: 0.3272 - val_accuracy: 0.6882 - val_auc_1: 0.9304 - val_loss: 0.3309 - learning_rate: 0.0010
Epoch 14/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7141 - auc_1: 0.9289 - loss: 0.3243 - val_accuracy: 0.7044 - val_auc_1: 0.9376 - val_loss: 0.3119 - learning_rate: 0.0010
Epoch 15/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7205 - auc_1: 0.9120 - loss: 0.3443 - val_accuracy: 0.6843 - val_auc_1: 0.9323 - val_loss: 0.3252 - learning_rate: 0.0010
Epoch 16/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7030 - auc_1: 0.9160 - loss: 0.3447 - val_accuracy: 0.6872 - val_auc_1: 0.9353 - val_loss: 0.3182 - learning_rate: 0.0010
Epoch 17/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7090 - auc_1: 0.9258 - loss: 0.3164 - val_accuracy: 0.6970 - val_auc_1: 0.9358 - val_loss: 0.3050 - learning_rate: 5.0000e-04
Epoch 18/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7129 - auc_1: 0.9254 - loss: 0.3153 - val_accuracy: 0.6853 - val_auc_1: 0.9327 - val_loss: 0.3222 - learning_rate: 5.0000e-04
Epoch 19/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - accuracy: 0.7125 - auc_1: 0.9290 - loss: 0.3093 - val_accuracy: 0.6936 - val_auc_1: 0.9363 - val_loss: 0.3136 - learning_rate: 5.0000e-04
Epoch 20/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7094 - auc_1: 0.9136 - loss: 0.3430 - val_accuracy: 0.6911 - val_auc_1: 0.9364 - val_loss: 0.3267 - learning_rate: 5.0000e-04
Epoch 21/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.7106 - auc_1: 0.9217 - loss: 0.3138 - val_accuracy: 0.6975 - val_auc_1: 0.9386 - val_loss: 0.3131 - learning_rate: 5.0000e-04
Epoch 22/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.7096 - auc_1: 0.9174 - loss: 0.3282 - val_accuracy: 0.6907 - val_auc_1: 0.9359 - val_loss: 0.3247 - learning_rate: 2.5000e-04
Epoch 23/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.7119 - auc_1: 0.9241 - loss: 0.3325 - val_accuracy: 0.6902 - val_auc_1: 0.9373 - val_loss: 0.3151 - learning_rate: 2.5000e-04
Epoch 24/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.7017 - auc_1: 0.9295 - loss: 0.3036 - val_accuracy: 0.6911 - val_auc_1: 0.9370 - val_loss: 0.3170 - learning_rate: 2.5000e-04
Epoch 25/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7080 - auc_1: 0.9372 - loss: 0.2912 - val_accuracy: 0.6980 - val_auc_1: 0.9383 - val_loss: 0.3081 - learning_rate: 2.5000e-04
Epoch 26/50
256/256 ━━━━━━━━━━━━━━━━━━━━ 1s 4ms/step - accuracy: 0.7060 - auc_1: 0.9212 - loss: 0.3306 - val_accuracy: 0.6911 - val_auc_1: 0.9385 - val_loss: 0.3177 - learning_rate: 2.5000e-04
_images/95140114ed0756c40e13ace88cad96975bec29030189b1336554f24085289ab4.png
80/80 ━━━━━━━━━━━━━━━━━━━━ 0s 3ms/step

Accuracy: 0.7086

Classification Report:
              precision    recall  f1-score   support

           0       0.14      0.98      0.24       121
           1       1.00      0.69      0.82      2432

    accuracy                           0.71      2553
   macro avg       0.57      0.84      0.53      2553
weighted avg       0.96      0.71      0.79      2553
_images/4b94b005bb41835eed15c05b171f6241bea494b28813c42d4a7dbea3fde07162.png _images/ed4c4c4ed9c524a16d2ec9b1c58d652fda30f5a10777851f70121c6f87731645.png
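The report above combines a high validation AUC (about 0.94) with a test accuracy of only 0.71, and the very low precision for class 0 is easier to interpret once the rounded figures are turned back into approximate counts. The sketch below is a rough reconstruction from the printed recall and support values only; because those values are rounded, the counts are approximations rather than the notebook's actual confusion matrix, and the variable names are illustrative.

```python
# Values copied from the classification report above (class 0 = lead, class 1 = sea ice)
support = {0: 121, 1: 2432}
recall = {0: 0.98, 1: 0.69}

# Approximate number of correctly classified samples per class
tp_lead = round(recall[0] * support[0])
tp_ice = round(recall[1] * support[1])

# Sea-ice samples that were predicted as leads (false positives for class 0)
fp_lead = support[1] - tp_ice

# Precision for the lead class implied by these counts
precision_lead = tp_lead / (tp_lead + fp_lead)

print(f"Leads correctly identified:    ~{tp_lead} of {support[0]}")
print(f"Sea ice misclassified as lead: ~{fp_lead} of {support[1]}")
print(f"Implied lead precision:        ~{precision_lead:.2f}")
```

With roughly twenty sea-ice samples for every lead, misclassifying about a third of the sea ice is enough to swamp the correct lead predictions, which is consistent with the reported precision of 0.14 for class 0 despite its recall of 0.98. Adjusting the decision threshold may trade some lead recall for better precision, but that is not something the notebook evaluates.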
# ----- Feature Importance Analysis -----

# Use SHAP for feature importance
print("\nCalculating feature importance using SHAP values...")

try:
    # Create an explainer for the model
    explainer = shap.Explainer(model, X_train_scaled)
    shap_values = explainer(X_test_scaled)

    # Get mean absolute SHAP values for feature importance
    feature_importance = np.abs(shap_values.values).mean(0)

    # Create a DataFrame for better visualization
    importance_df = pd.DataFrame({
        'Feature': feature_names,
        'Importance': feature_importance
    }).sort_values('Importance', ascending=False)

    print("\nFeature Importance Ranking:")
    for idx, row in importance_df.iterrows():
        print(f"{row['Feature']}: {row['Importance']:.4f}")

    # Plot feature importance
    plt.figure(figsize=(10, 6))
    plt.barh(importance_df['Feature'], importance_df['Importance'], color='skyblue')
    plt.xlabel('Mean |SHAP Value|')
    plt.ylabel('Feature')
    plt.title('Feature Importance (SHAP Values)')
    plt.tight_layout()
    plt.show()

    # Plot SHAP summary plot (this shows direction of impact too)
    plt.figure(figsize=(12, 8))
    shap.summary_plot(shap_values, X_test_scaled, feature_names=feature_names, show=False)
    plt.title('SHAP Summary Plot')
    plt.tight_layout()
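    plt.show()

except Exception as e:
    print(f"Error calculating SHAP values: {e}")

    # Fallback when SHAP fails: rank features by the mean absolute weight they
    # receive in the first layer. This assumes model.layers[0] is a Dense layer
    # whose kernel has shape (n_features, units), and it ignores every later
    # layer, so treat the result as a rough proxy rather than a true attribution.
    weights = model.layers[0].get_weights()[0]
    importance = np.abs(weights).mean(axis=1)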

    # Create a DataFrame for better visualization
    importance_df = pd.DataFrame({
        'Feature': feature_names,
        'Importance': importance
    }).sort_values('Importance', ascending=False)

    print("\nFeature Importance Ranking (from model weights):")
    for idx, row in importance_df.iterrows():
        print(f"{row['Feature']}: {row['Importance']:.4f}")

    # Plot feature importance
    plt.figure(figsize=(10, 6))
    plt.barh(importance_df['Feature'], importance_df['Importance'], color='skyblue')
    plt.xlabel('Absolute Weight Magnitude')
    plt.ylabel('Feature')
    plt.title('Feature Importance (Neural Network Weights)')
    plt.tight_layout()
    plt.show()

print("Waveform feature-based classification complete!")
Calculating feature importance using SHAP values...
ExactExplainer explainer: 2554it [00:45, 46.92it/s]                          
Feature Importance Ranking:
Leading Edge Width: 0.0947
Pulse Peakiness: 0.0826
Leading Edge Slope: 0.0653
Trailing Edge Slope: 0.0598
Waveform Maximum: 0.0378
Trailing Edge Decline: 0.0332
Waveform Noise: 0.0253
Max Position: 0.0145
Waveform Width: 0.0000
_images/5bf9e08b81e5ebd985bc5255df6a5350ea0308487ce55cd967fe34d4d47cfc4e.png _images/2e87c7820fd1a1eeccde921d309dae96ea9f9c14627b2c418e4f0ab1befe35cf.png
Waveform feature-based classification complete!

4. Tree-Based Classification Models with Engineered Waveform Features#

Random Forest#

  • Trains a Random Forest Classifier with 100 trees and balanced class weights (class_weight='balanced') on the extracted waveform features.

  • Evaluates model performance using accuracy and classification metrics.

  • Ranks feature importance to identify the most influential waveform characteristics (e.g., Pulse Peakiness, Leading Edge Width).
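The importances ranked here come from the forest's feature_importances_ attribute, which is impurity-based and computed on the training data. Permutation importance on held-out data is a common cross-check. The following is a minimal sketch under stated assumptions: it uses synthetic, imbalanced data from make_classification as a stand-in for the nine waveform features, and the names X_tr, X_te, impurity_imp and perm_imp are illustrative rather than taken from the notebook.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced stand-in for the nine engineered features (illustrative only)
X, y = make_classification(
    n_samples=2000, n_features=9, n_informative=4, n_redundant=2,
    weights=[0.05, 0.95], random_state=42,
)
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

rf = RandomForestClassifier(n_estimators=100, class_weight="balanced", random_state=42)
rf.fit(X_tr, y_tr)

# Impurity-based importances, as reported by feature_importances_
impurity_imp = rf.feature_importances_

# Permutation importance on the held-out split, scored with balanced accuracy
# so that the minority class is not drowned out by the majority class
perm = permutation_importance(
    rf, X_te, y_te, scoring="balanced_accuracy", n_repeats=10, random_state=42
)
perm_imp = perm.importances_mean

# Print both measures side by side, ordered by impurity-based importance
for i in np.argsort(impurity_imp)[::-1]:
    print(f"feature {i}: impurity={impurity_imp[i]:.3f}  permutation={perm_imp[i]:.3f}")
```

In the notebook itself, the same call would presumably take rf_model, X_test_scaled and y_test. Features that rank highly under both measures are less likely to be artefacts of how impurity-based importance is computed.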

Gradient Boosting#

  • Implements a Gradient Boosting Classifier on waveform features; each new tree is fitted to correct the errors of the ensemble built so far.

  • Evaluates model performance using accuracy and classification metrics.

  • Ranks feature importance to identify the most influential waveform characteristics (e.g., Pulse Peakiness, Leading Edge Width).
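Unlike RandomForestClassifier, scikit-learn's GradientBoostingClassifier has no class_weight parameter, so class imbalance has to be handled another way if it is handled at all. One option is to pass per-sample weights to fit. The sketch below assumes synthetic data as a stand-in for the waveform features; gb_weighted and the split variables are hypothetical names, and nothing here claims that weighting improves the notebook's results.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_sample_weight

# Synthetic, imbalanced stand-in for the engineered waveform features
X, y = make_classification(
    n_samples=2000, n_features=9, n_informative=4, n_redundant=2,
    weights=[0.05, 0.95], random_state=42,
)
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# "balanced" gives each sample a weight inversely proportional to its class
# frequency, so both classes contribute equally to the training loss
weights = compute_sample_weight(class_weight="balanced", y=y_tr)

gb_weighted = GradientBoostingClassifier(
    n_estimators=100, learning_rate=0.1, max_depth=5, random_state=42
)
gb_weighted.fit(X_tr, y_tr, sample_weight=weights)

print(classification_report(y_te, gb_weighted.predict(X_te)))
```

Comparing the minority-class recall with and without sample_weight is a quick way to see whether the weighting is worth keeping for a given dataset.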

XGBoost#

  • Trains an XGBoost Classifier using engineered waveform features, leveraging regularised gradient boosting.

  • Evaluates model performance using accuracy and classification metrics.

  • Ranks feature importance to identify the most influential waveform characteristics (e.g., Pulse Peakiness, Leading Edge Width).
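XGBoost exposes a scale_pos_weight parameter for imbalanced binary problems; its documentation suggests the ratio of negative to positive samples as a typical value. The sketch below only illustrates how that ratio could be computed. The label array is hypothetical: it reuses the class ratio from the test-set reports (121 leads, 2432 sea ice) because the training-set counts are not shown, and the import is guarded in case xgboost is not installed.

```python
import numpy as np

# Hypothetical labels with the same class ratio as the test-set reports
# (class 0 = lead, class 1 = sea ice); replace with the real y_train
y_train_demo = np.array([0] * 121 + [1] * 2432)

n_neg = int(np.sum(y_train_demo == 0))  # negative class (leads)
n_pos = int(np.sum(y_train_demo == 1))  # positive class (sea ice)

# Ratio of negative to positive samples; values below 1 down-weight
# the majority positive class
scale_pos_weight = n_neg / n_pos
print(f"scale_pos_weight = {scale_pos_weight:.4f}")

try:
    import xgboost as xgb

    # Same settings as a plain XGBClassifier, plus the imbalance correction
    xgb_weighted = xgb.XGBClassifier(
        n_estimators=100,
        max_depth=5,
        learning_rate=0.1,
        eval_metric="logloss",
        scale_pos_weight=scale_pos_weight,
        random_state=42,
    )
except ImportError:
    xgb_weighted = None  # xgboost not available; the ratio above is still valid
```

Because sea ice is the positive and majority class here, the ratio is well below 1. Whether this helps lead recall would need to be checked against the unweighted model.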

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_curve, auc
import seaborn as sns

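# Function to plot feature importance overlaid on average waveforms.
# Note: feature_importance is drawn against the waveform sample index, so it
# should hold one value per waveform bin (len(feature_importance) == X_clean.shape[1]),
# not one value per engineered feature.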
def plot_importance_on_waveform(feature_importance, model_name, X_clean, y_clean):
    plt.figure(figsize=(15, 6))

    # Plot average waveforms
    waveforms_class0 = X_clean[y_clean == 0]
    waveforms_class1 = X_clean[y_clean == 1]

    if len(waveforms_class0) > 0:
        plt.plot(np.mean(waveforms_class0, axis=0), 'b-', label='Lead (Class 0)', alpha=0.5)
    if len(waveforms_class1) > 0:
        plt.plot(np.mean(waveforms_class1, axis=0), 'r-', label='Sea Ice (Class 1)', alpha=0.5)

    # Scale importance for visualization
    max_wave_amp = max(
        np.max(np.mean(waveforms_class0, axis=0)) if len(waveforms_class0) > 0 else 0,
        np.max(np.mean(waveforms_class1, axis=0)) if len(waveforms_class1) > 0 else 0
    )

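    # Accept lists as well as arrays so the scaling below works element-wise
    feature_importance = np.asarray(feature_importance, dtype=float)
    if feature_importance.max() > 0:  # Avoid division by zero
        importance_scaling = max_wave_amp / feature_importance.max() * 2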
        plt.bar(range(len(feature_importance)),
                feature_importance * importance_scaling,
                alpha=0.3,
                color='g',
                label='Feature Importance')

    plt.title(f'{model_name}: Feature Importance vs. Waveform Patterns')
    plt.xlabel('Sample Index')
    plt.ylabel('Amplitude / Importance')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()
from sklearn.ensemble import RandomForestClassifier

print("=== Training Random Forest Classifier with Waveform Features ===")
rf_model = RandomForestClassifier(
    n_estimators=100,
    max_depth=20,
    min_samples_split=5,
    min_samples_leaf=2,
    random_state=42,
    class_weight='balanced'
)

# Train the model on extracted features
rf_model.fit(X_train_scaled, y_train)

# Make predictions
y_pred_rf = rf_model.predict(X_test_scaled)
y_pred_prob_rf = rf_model.predict_proba(X_test_scaled)[:, 1]

# Evaluate
accuracy_rf = accuracy_score(y_test, y_pred_rf)
print(f"Random Forest Accuracy: {accuracy_rf:.4f}")
print("\nClassification Report:")
print(classification_report(y_test, y_pred_rf))

# Get feature importance
feature_importance_rf = rf_model.feature_importances_

# Create feature importance visualization
plt.figure(figsize=(10, 6))
importance_df = pd.DataFrame({
    'Feature': feature_names,
    'Importance': feature_importance_rf
}).sort_values('Importance', ascending=False)

plt.barh(importance_df['Feature'], importance_df['Importance'], color='skyblue')
plt.xlabel('Importance')
plt.ylabel('Feature')
plt.title('Random Forest Feature Importance')
plt.tight_layout()
plt.show()
=== Training Random Forest Classifier with Waveform Features ===
Random Forest Accuracy: 0.9749

Classification Report:
              precision    recall  f1-score   support

           0       0.80      0.63      0.70       121
           1       0.98      0.99      0.99      2432

    accuracy                           0.97      2553
   macro avg       0.89      0.81      0.85      2553
weighted avg       0.97      0.97      0.97      2553
_images/dc78d0777089ea1b996ea7ed8e19242fea2b20a32ceb3a78a7f2ba99cbff9654.png
from sklearn.ensemble import GradientBoostingClassifier

print("=== Training Gradient Boosting Classifier with Waveform Features ===")
gb_model = GradientBoostingClassifier(
    n_estimators=100,
    learning_rate=0.1,
    max_depth=5,
    min_samples_split=10,
    random_state=42
)

# Train the model on extracted features
gb_model.fit(X_train_scaled, y_train)

# Make predictions
y_pred_gb = gb_model.predict(X_test_scaled)
y_pred_prob_gb = gb_model.predict_proba(X_test_scaled)[:, 1]

# Evaluate
accuracy_gb = accuracy_score(y_test, y_pred_gb)
print(f"Gradient Boosting Accuracy: {accuracy_gb:.4f}")
print("\nClassification Report:")
print(classification_report(y_test, y_pred_gb))

# Get feature importance
feature_importance_gb = gb_model.feature_importances_

# Create feature importance visualization
plt.figure(figsize=(10, 6))
importance_df = pd.DataFrame({
    'Feature': feature_names,
    'Importance': feature_importance_gb
}).sort_values('Importance', ascending=False)

plt.barh(importance_df['Feature'], importance_df['Importance'], color='green')
plt.xlabel('Importance')
plt.ylabel('Feature')
plt.title('Gradient Boosting Feature Importance')
plt.tight_layout()
plt.show()
=== Training Gradient Boosting Classifier with Waveform Features ===
Gradient Boosting Accuracy: 0.9738

Classification Report:
              precision    recall  f1-score   support

           0       0.80      0.60      0.68       121
           1       0.98      0.99      0.99      2432

    accuracy                           0.97      2553
   macro avg       0.89      0.79      0.83      2553
weighted avg       0.97      0.97      0.97      2553
_images/645e2d8ebf2cec603eb64cf188a230785eb7dcefec5bf817b2877b193df714ca.png
import xgboost as xgb

print("=== Training XGBoost Classifier with Waveform Features ===")
xgb_model = xgb.XGBClassifier(
    n_estimators=100,
    max_depth=5,
    learning_rate=0.1,
    use_label_encoder=False,  # deprecated: ignored by recent XGBoost releases, which emit a UserWarning; safe to remove
    eval_metric='logloss',
    random_state=42
)

# Train the model on extracted features
xgb_model.fit(X_train_scaled, y_train)

# Make predictions
y_pred_xgb = xgb_model.predict(X_test_scaled)
y_pred_prob_xgb = xgb_model.predict_proba(X_test_scaled)[:, 1]

# Evaluate
accuracy_xgb = accuracy_score(y_test, y_pred_xgb)
print(f"XGBoost Accuracy: {accuracy_xgb:.4f}")
print("\nClassification Report:")
print(classification_report(y_test, y_pred_xgb))

# Get feature importance
feature_importance_xgb = xgb_model.feature_importances_

# Create feature importance visualization
plt.figure(figsize=(10, 6))
importance_df = pd.DataFrame({
    'Feature': feature_names,
    'Importance': feature_importance_xgb
}).sort_values('Importance', ascending=False)

plt.barh(importance_df['Feature'], importance_df['Importance'], color='red')
plt.xlabel('Importance')
plt.ylabel('Feature')
plt.title('XGBoost Feature Importance')
plt.tight_layout()
plt.show()
=== Training XGBoost Classifier with Waveform Features ===
/usr/local/lib/python3.11/dist-packages/xgboost/core.py:158: UserWarning: [16:53:31] WARNING: /workspace/src/learner.cc:740: 
Parameters: { "use_label_encoder" } are not used.

  warnings.warn(smsg, UserWarning)
XGBoost Accuracy: 0.9745

Classification Report:
              precision    recall  f1-score   support

           0       0.84      0.57      0.68       121
           1       0.98      0.99      0.99      2432

    accuracy                           0.97      2553
   macro avg       0.91      0.78      0.83      2553
weighted avg       0.97      0.97      0.97      2553
_images/519c3d1ecf16480ea524fe011f0a165a42b9c42b878cd17430cb163bf8c484ee.png
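The three tree-based models each produce their own importance vector (feature_importance_rf, feature_importance_gb and feature_importance_xgb), and the values are not on a common scale. One simple way to look for consensus is to normalise each vector and average the per-model ranks. The sketch below is an assumption about how such a comparison could be done, not the notebook's own method; consensus_ranking is a hypothetical helper, and the importance values are random placeholders, not results.

```python
import numpy as np
import pandas as pd

# Feature names as printed in the SHAP ranking earlier in the notebook
feature_names = [
    "Leading Edge Width", "Pulse Peakiness", "Leading Edge Slope",
    "Trailing Edge Slope", "Waveform Maximum", "Trailing Edge Decline",
    "Waveform Noise", "Max Position", "Waveform Width",
]

def consensus_ranking(importances, names):
    """Average per-model ranks (1 = most important) across several models."""
    table = pd.DataFrame(importances, index=names)
    # Normalise each model's importances to sum to 1 so columns are comparable
    table = table / table.sum(axis=0)
    # Rank features within each model, highest importance first
    ranks = table.rank(ascending=False, axis=0)
    table["mean_rank"] = ranks.mean(axis=1)
    return table.sort_values("mean_rank")

# Random placeholder importances (NOT notebook results), one vector per model
rng = np.random.default_rng(0)
placeholder = {
    name: rng.dirichlet(np.ones(len(feature_names)))
    for name in ("rf", "gb", "xgb")
}

consensus = consensus_ranking(placeholder, feature_names)
print(consensus.round(3))
```

In the notebook, the placeholder dictionary would be replaced by the three fitted models' importance arrays. Rank averaging is deliberately crude; it ignores how large the gaps between features are, so it is best read alongside the individual bar charts.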