Explainable AI
Week 9 materials can be accessed here.
In week 9, we focus more closely on how to interpret some of the models we have covered. Explainability is a central goal when working with AI algorithms: we want to understand what a model is doing rather than using it as a black box.
Why is it important?
Explainable AI (XAI) or interpretable AI is crucial for the following reasons.
Enhancing Trust and Confidence
Transparency: When users and stakeholders can understand how a model makes its decisions, they are more likely to trust the technology. This is especially important in critical applications where the stakes are high. For example, if we deploy a model to detect sea-ice leads over a full month, we want to understand as much as possible about how it reaches its results; this understanding can also reveal what other problems the model may be able to handle.
Accountability: Explainability facilitates accountability by making it possible to trace the decision-making process. This is essential for identifying and correcting biases, and it offers insight into the scientific reasons why a model works and how to develop it further.
Improving Model Performance and Reliability
Debugging and Improvement: Interpretability also helps developers and data scientists identify errors or biases in models. Understanding why a model makes certain decisions allows for targeted improvements, leading to more accurate and robust systems.
Feature Importance: By understanding which features contribute most to the model’s predictions, researchers can focus on collecting and processing the most relevant data, potentially reducing costs and increasing model efficiency.
This week, we focus on feature importance: finding out which parts of the data contribute most to the model's inference.
Understanding Feature Importance
Before diving into specific interpretability methods like feature importance in Random Forest and sensitivity analysis in Convolutional Neural Networks (CNNs), it’s crucial to establish a foundational understanding of feature importance within the broader context of interpretable machine learning. This prerequisite knowledge will set the stage for a deeper exploration of how different models provide insights into the significance of input features.
The Concept of Feature Importance
Feature importance is a technique used to identify which input features have the most influence on a model’s predictions. This concept is pivotal in both developing and interpreting machine learning models, as it helps us understand the data’s underlying structure and the model’s decision-making process.
Intrinsic vs. Post Hoc Interpretability
Intrinsic Interpretability: Some models, like decision trees in a Random Forest, naturally offer insights into feature importance due to their transparent structure. Here, the importance is derived directly from the model itself, without the need for additional analysis tools.
Post Hoc Interpretability: For more complex models, such as CNNs, post hoc methods like sensitivity analysis are employed to interpret the model’s behavior. These techniques analyze the model’s output in response to changes in input features, shedding light on feature importance even when the model’s internal workings are not directly interpretable.
Importance for Model Understanding and Optimization
Understanding feature importance is not merely academic; it has practical applications in model optimization, data collection strategies, and ultimately, in making models more transparent and trustworthy.
Random Forest: By examining feature importance, we can understand which criteria the ensemble of trees uses to make decisions, guiding us in model refinement and data preprocessing.
CNN Sensitivity Analysis: Sensitivity analysis reveals how changes in input image pixels (or bands in multi-band images) affect the model’s confidence in its predictions. This insight can direct attention to the most relevant parts of the data, informing feature engineering and network architecture adjustments.
Preparing for Specific Interpretability Methods
As we approach the topics of feature importance in Random Forest and sensitivity analysis in CNNs, it’s essential to appreciate the versatility and applicability of feature importance across different model types. Whether through the intrinsic interpretability of simpler models or the post hoc analysis of complex networks, understanding which features significantly impact model predictions is a key step towards achieving transparency, fairness, and effectiveness in machine learning applications.
Random Forest
With a random forest, we can determine how important each band is to the classification. Like other tree-based models (such as single decision trees), a random forest can compute feature importance scores, which indicate how useful each feature is for making the classification decision. In our case, the features are the spectral bands of the imagery data.
Understanding Feature Importance in Random Forest
Random Forest is a powerful ensemble learning method that operates by constructing a multitude of decision trees during training time and outputting the class that is the mode of the classes (classification) or mean prediction (regression) of the individual trees. A key strength of Random Forest is its ability to handle high-dimensional data and its provision of intuitive metrics for understanding which features contribute most to the predictive accuracy of the model.
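To make the aggregation step concrete, here is a minimal NumPy sketch of two ways an ensemble can combine per-tree outputs for classification: a hard majority vote (the mode described above) and averaging the trees' class probabilities. The per-tree probabilities are toy values chosen for illustration. Note that scikit-learn's `RandomForestClassifier` uses the probability-averaging variant rather than a hard vote, and, as the toy values show, the two rules can disagree.

```python
import numpy as np

def majority_vote(tree_labels: np.ndarray, n_classes: int) -> np.ndarray:
    """Hard voting. tree_labels has shape (n_trees, n_samples) with integer labels."""
    # Count votes for each class per sample -> shape (n_samples, n_classes)
    counts = np.stack([(tree_labels == c).sum(axis=0) for c in range(n_classes)], axis=1)
    # The mode of the votes (ties go to the lowest class index)
    return counts.argmax(axis=1)

def average_probabilities(tree_probas: np.ndarray) -> np.ndarray:
    """Soft voting. tree_probas has shape (n_trees, n_samples, n_classes)."""
    return tree_probas.mean(axis=0).argmax(axis=1)

# Toy class probabilities from 3 trees for 2 samples (illustrative values only)
tree_probas = np.array([
    [[0.60, 0.40], [0.45, 0.55]],
    [[0.55, 0.45], [0.40, 0.60]],
    [[0.10, 0.90], [0.90, 0.10]],
])
# Each tree's own predicted label is its most probable class
tree_labels = tree_probas.argmax(axis=2)

print(majority_vote(tree_labels, n_classes=2))  # [0 1]
print(average_probabilities(tree_probas))       # [1 0]
```

Two trees are mildly confident in one class while the third is very confident in the other, so the vote and the probability average point in opposite directions.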
How Does Random Forest Compute Feature Importance?
The concept of feature importance in Random Forest emerges naturally from how the trees are constructed. Each tree makes decisions by splitting nodes on the value of a single feature at a time. The split at each node is chosen according to a criterion that measures the “improvement” it brings to the purity of the node (e.g., Gini impurity for classification tasks).
Feature importance in Random Forest is calculated based on how much each feature contributes to this improvement across all trees in the forest. In essence, the more a feature decreases the impurity of the tree, the more important that feature is considered to be. This is quantified in a metric often referred to as “Gini importance” or “mean decrease in impurity” (MDI).
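The impurity decrease behind MDI can be computed by hand for a single split. The sketch below assumes Gini impurity \(G = 1 - \sum_k p_k^2\) and a toy node of labels, and computes the weighted decrease \(\frac{N_t}{N} G_t - \frac{N_L}{N} G_L - \frac{N_R}{N} G_R\), where \(N_t\), \(N_L\), \(N_R\) are the sample counts at the node and its two children and \(N\) is the count at the root. A feature's MDI score in one tree sums this quantity over every node that splits on that feature; scikit-learn then normalizes each tree's scores and averages them across the forest.

```python
import numpy as np

def gini(labels: np.ndarray) -> float:
    """Gini impurity 1 - sum_k p_k^2 of a set of class labels."""
    if labels.size == 0:
        return 0.0
    _, counts = np.unique(labels, return_counts=True)
    p = counts / labels.size
    return 1.0 - np.sum(p ** 2)

def impurity_decrease(parent: np.ndarray, left: np.ndarray, right: np.ndarray, n_total: int) -> float:
    """Weighted impurity decrease of one split, the quantity MDI accumulates.

    n_total is the number of training samples at the root, so splits that
    see more samples (typically near the root) contribute more."""
    n_t, n_l, n_r = parent.size, left.size, right.size
    return (n_t / n_total) * gini(parent) - (n_l / n_total) * gini(left) - (n_r / n_total) * gini(right)

# Toy node with 8 samples (illustrative: 0 = sea ice, 1 = lead)
parent = np.array([0, 0, 0, 0, 1, 1, 1, 1])
left, right = parent[:4], parent[4:]  # a perfect split on some band's value
print(gini(parent))                                       # 0.5
print(impurity_decrease(parent, left, right, n_total=8))  # 0.5
```

A perfect split removes all of the parent's impurity, so the whole 0.5 is credited to the feature used at that node.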
Interpreting Feature Importance
After training a Random Forest model, each feature is assigned an importance score, which can be normalized to sum up to one. These scores provide a ranked list of features according to their importance:
High Importance: Features that frequently contribute to improving node purity across many trees. Such features are critical for the model’s predictions and often represent key variables that define the classification or regression problem.
Low Importance: Features that contribute little to node purity. These features have minimal impact on the model’s decision-making process and could potentially be removed without significant loss of model accuracy.
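One way to see this ranking behave as described is to fit a forest on synthetic data where we know which features matter. The sketch assumes scikit-learn's `make_classification`; with `shuffle=False` the first `n_informative` columns are informative and the remaining columns are pure noise, so the informative columns should head the ranking and the noise columns should receive low scores.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic data: with shuffle=False, columns 0-1 are informative and 2-5 are noise
X, y = make_classification(n_samples=1000, n_features=6, n_informative=2,
                           n_redundant=0, n_repeated=0, n_clusters_per_class=1,
                           class_sep=2.0, shuffle=False, random_state=0)

forest = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)
importances = forest.feature_importances_

# The scores are normalized, so they sum to one (up to floating-point rounding)
print(importances.sum())

# Rank features from most to least important
ranking = np.argsort(importances)[::-1]
for idx in ranking:
    print(f"feature {idx}: {importances[idx]:.3f}")
```

In practice, low-importance features are candidates for removal, but it is worth retraining without them to confirm that accuracy really does not drop.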
Practical Application: Analyzing Band Importance in Multi-band Images
In the context of multi-band image classification, feature importance can be leveraged to identify which spectral bands are most valuable for discriminating between classes. This insight is invaluable for tasks like satellite image analysis, where understanding which bands (e.g., visible, infrared) are most informative can guide data collection and preprocessing strategies.
```python
from sklearn.ensemble import RandomForestClassifier
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# X_train is assumed to have shape (n_samples, 3, 3, n_bands) and y_train to hold
# the labels (both created as in the Week 2 sea ice and lead classification notebook).
# Keep only the central pixel of each 3x3 patch, giving one feature per band.
# A new name is used so the 4-D X_train stays available for the CNN below.
X_train_center = X_train[:, 1, 1, :]
feature_names = [f"feature {i}" for i in range(X_train_center.shape[1])]

forest = RandomForestClassifier(random_state=0)
forest.fit(X_train_center, y_train)

# Mean decrease in impurity (MDI), averaged over all trees and normalized to sum to one
importances = forest.feature_importances_
# Spread of the per-tree importances, drawn as error bars
std = np.std([tree.feature_importances_ for tree in forest.estimators_], axis=0)
forest_importances = pd.Series(importances, index=feature_names)

fig, ax = plt.subplots()
forest_importances.plot.bar(yerr=std, ax=ax)
ax.set_title("Feature importances using MDI")
ax.set_ylabel("Mean decrease in impurity")
fig.tight_layout()
plt.show()
```
CNNs (Convolutional Neural Networks) and Sensitivity Analysis
Convolutional Neural Networks (CNNs) are pivotal in image processing and classification because they learn spatial hierarchies of features. A pressing question when using CNNs is which parts of the input data (e.g., specific bands in a multi-band image) most influence the model's predictions. Sensitivity analysis is a key technique for revealing the importance of input features in the model's decision-making process: despite the complexity of CNNs, it offers a feasible route to interpretation.
Understanding Sensitivity Analysis in CNNs
Sensitivity analysis quantifies the effect of minor changes in the input image on the output predictions. This process involves calculating the gradient of the output with respect to the input, exploring the question: “How does a small alteration in each input pixel (or band) modify the predicted class scores?”
Mathematical Foundation of Sensitivity Analysis
For a CNN model function \(f(x)\) mapping an input image \(x\) to output predictions, the sensitivity of the output relative to an input pixel (or band) is mathematically denoted by the partial derivative \(\frac{\partial f}{\partial x_i}\), where \(x_i\) signifies the specific pixel or band.
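The partial derivative can also be approximated numerically, which is a useful sanity check on gradient-based sensitivity. This sketch replaces the CNN with a toy logistic model \(f(x) = \sigma(w^\top x + b)\) (an assumption for illustration), whose gradient has the closed form \(\nabla_x f = f(x)\,(1 - f(x))\,w\). A central finite difference \(\frac{f(x + h e_i) - f(x - h e_i)}{2h}\), perturbing one input \(x_i\) at a time, should agree with it.

```python
import numpy as np

# Toy model standing in for a CNN: f(x) = sigmoid(w . x + b), illustrative weights
w = np.array([2.0, -1.0, 0.0, 0.5])
b = 0.1

def f(x: np.ndarray) -> float:
    return 1.0 / (1.0 + np.exp(-(w @ x + b)))

def analytic_grad(x: np.ndarray) -> np.ndarray:
    # d sigmoid(z)/dz = sigmoid(z)(1 - sigmoid(z)); the chain rule adds a factor of w
    s = f(x)
    return s * (1.0 - s) * w

def finite_difference_grad(x: np.ndarray, h: float = 1e-5) -> np.ndarray:
    # Central difference: nudge one input x_i up and down, keep the others fixed
    grad = np.zeros_like(x)
    for i in range(x.size):
        e = np.zeros_like(x)
        e[i] = h
        grad[i] = (f(x + e) - f(x - e)) / (2 * h)
    return grad

x = np.array([0.3, 0.8, -0.2, 1.0])
print(analytic_grad(x))
print(finite_difference_grad(x))  # should match to many decimal places
```

The third input has zero weight, so its sensitivity is zero: changing it cannot move the output.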
Implementing Sensitivity Analysis: Steps
Image and Prediction Class Selection: For an input image \(x\) and a target class \(c\) (typically the class with the highest model prediction score), the objective is to determine how sensitive the prediction is to variations in \(x\).
Gradient Computation: Use TensorFlow's GradientTape to compute the gradient of the class score \(f_c(x)\) with respect to the input image \(x\), denoted \(\nabla_x f_c(x)\). This gradient comprises the partial derivatives \(\frac{\partial f_c}{\partial x_i}\) for each pixel or band \(x_i\), indicating how minor changes in \(x_i\) affect the score \(f_c(x)\).
Gradient Magnitude Analysis: The element-wise magnitude of these gradients, \(|\nabla_x f_c(x)|\), shows how sensitive the prediction is to each part of the input image. Higher magnitudes indicate greater sensitivity, implying that those input features strongly influence the model's decision.
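The three steps can be traced end to end without TensorFlow on a toy model. The sketch below assumes a linear softmax classifier over a flattened multi-band patch (a stand-in, not a CNN), chosen because its input gradient has a closed form, \(\nabla_x f_c = f_c\,(W_c - \sum_k f_k W_k)\), where \(W_k\) is the weight row of class \(k\). It then sums \(|\nabla_x f_c|\) over the two spatial axes to get one score per band, mirroring the per-band reduction in the TensorFlow example later in this section. The patch size, band count and random weights are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
HEIGHT, WIDTH, BANDS, CLASSES = 3, 3, 4, 2  # toy 3x3 patch with 4 bands, 2 classes
D = HEIGHT * WIDTH * BANDS

# Hypothetical linear softmax classifier on the flattened patch (stand-in for a CNN)
weights = rng.normal(size=(CLASSES, D))
weights[:, 2::BANDS] = 0.0  # band 2 is ignored by the model, so its sensitivity must be 0
bias = np.zeros(CLASSES)

def softmax(z: np.ndarray) -> np.ndarray:
    e = np.exp(z - z.max())  # shift for numerical stability
    return e / e.sum()

def class_score_gradient(image: np.ndarray, c: int) -> np.ndarray:
    """Gradient of f_c with respect to every input value, same shape as image."""
    p = softmax(weights @ image.ravel() + bias)
    grad_flat = p[c] * (weights[c] - p @ weights)
    return grad_flat.reshape(image.shape)

image = rng.normal(size=(HEIGHT, WIDTH, BANDS))
# Step 1: choose the class with the highest predicted score
c = int(np.argmax(softmax(weights @ image.ravel() + bias)))
# Step 2: gradient of that class score with respect to the input
grad = class_score_gradient(image, c)
# Step 3: gradient magnitude, summed over the spatial axes -> one score per band
band_sensitivity = np.abs(grad).sum(axis=(0, 1))
print(band_sensitivity)
```

Summing absolute values, rather than raw gradients, keeps positive and negative influences from cancelling within a band.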
Effectiveness of This Approach
The efficacy of sensitivity analysis is rooted in calculus, specifically the derivative. By calculating how the output prediction for a class changes under infinitesimal variations in the input features, we obtain a direct, though local, measure of each feature's influence on the model's decision around that particular input.
Intuitive Interpretation: Identifying features with high sensitivity indicates they possess crucial information for the classification task. For instance, specific spectral bands might be vital for distinguishing particular objects or patterns in satellite imagery.
Visualization and Insights: Visualizing prediction sensitivity across different bands offers an intuitive view of feature importance, informing model architecture optimization and data collection strategies by prioritizing the most informative features.
Conclusion
Sensitivity analysis helps partially explain the complex operations of CNNs by revealing the significance of input features, including spectral bands in multi-band images. This methodology not only deepens our understanding of deep learning models but also guides practical decisions in model development and data strategy.
```python
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
import matplotlib.pyplot as plt

# Define the CNN architecture as a function so each ensemble member starts fresh
def create_model(input_shape, num_classes=2):
    model = models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation='relu', padding='same'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(num_classes, activation='softmax'),  # adjust num_classes if necessary
    ])
    return model

# Per-band sensitivity: |d f_c / d x| summed over the two spatial axes
def sensitivity_analysis(model, input_image, class_idx):
    input_image_tensor = tf.convert_to_tensor(input_image, dtype=tf.float32)
    with tf.GradientTape() as tape:
        tape.watch(input_image_tensor)  # inputs are constants, so watch them explicitly
        predictions = model(input_image_tensor, training=False)
        class_output = predictions[:, class_idx]
    gradients = tape.gradient(class_output, input_image_tensor)  # (batch, H, W, bands)
    band_sensitivity = tf.reduce_sum(tf.abs(gradients), axis=(1, 2))  # (batch, bands)
    return band_sensitivity.numpy()

# Ensemble approach: train several models and average their sensitivities
# to smooth out variation between training runs
num_models = 5
sample_image = X_test[:1].astype(np.float32)  # first test patch, kept as a batch of one
ensemble_sensitivities = []

for i in range(num_models):
    model = create_model(X_train.shape[1:])
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    model.fit(X_train, y_train, epochs=10, validation_split=0.1, verbose=0)

    # Explain the class this model predicts for the sample patch
    predictions = model.predict(sample_image, verbose=0)
    class_idx = int(np.argmax(predictions[0]))
    ensemble_sensitivities.append(sensitivity_analysis(model, sample_image, class_idx))

# Average over ensemble members; the result has shape (1, bands), so take row 0
average_sensitivity = np.mean(ensemble_sensitivities, axis=0)[0]

plt.figure(figsize=(10, 6))
plt.bar(range(X_train.shape[3]), average_sensitivity)
plt.xlabel('Band Number')
plt.ylabel('Average Sensitivity Score')
plt.title('Average Sensitivity of Prediction to Each Band Across Ensemble')
plt.show()
```
Comparing Important Bands in Reflectance vs. Radiance Data
Understanding the influence of data preprocessing on feature importance is crucial in remote sensing and machine learning applications. Specifically, it’s insightful to investigate whether the preprocessing steps that convert data into radiance or reflectance values affect which bands are deemed important by the model.
Objective
The objective of this exercise is to determine if the preprocessing transformation of data into radiance or reflectance impacts the importance assigned to different spectral bands by a machine learning model.
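One possible way to quantify whether preprocessing changes the ranking (a suggestion, not prescribed by the exercise) is to compare the two band-importance vectors with a Spearman rank correlation and a top-k overlap. The sketch assumes you already have one importance score per band for each dataset, for example `forest.feature_importances_` from a model trained on radiance and another trained on reflectance. The helper names and the arrays in the usage lines are hypothetical placeholders, not results.

```python
import numpy as np

def spearman_rank_correlation(a: np.ndarray, b: np.ndarray) -> float:
    """Spearman correlation of two importance vectors (assumes no tied scores)."""
    rank_a = np.argsort(np.argsort(a))  # rank of each band, 0 = least important
    rank_b = np.argsort(np.argsort(b))
    n = a.size
    d = rank_a - rank_b
    return 1.0 - 6.0 * np.sum(d ** 2) / (n * (n ** 2 - 1))

def top_k_overlap(a: np.ndarray, b: np.ndarray, k: int = 3) -> float:
    """Fraction of the k most important bands shared by both rankings."""
    top_a = set(np.argsort(a)[::-1][:k].tolist())
    top_b = set(np.argsort(b)[::-1][:k].tolist())
    return len(top_a & top_b) / k

# Placeholder importance vectors (illustrative only, not real results)
radiance_importance = np.array([0.05, 0.30, 0.10, 0.40, 0.15])
reflectance_importance = np.array([0.10, 0.25, 0.05, 0.45, 0.15])
print(spearman_rank_correlation(radiance_importance, reflectance_importance))
print(top_k_overlap(radiance_importance, reflectance_importance, k=3))
```

A correlation near 1 and full top-k overlap suggest the preprocessing leaves the important bands largely unchanged; values well below that suggest it shifts which bands the model relies on.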
Preparing the Datasets
Before comparing band importance, make sure you have prepared datasets in both radiance and reflectance form. Use the dataset image_2.npy as an example (related notebook: Week 2 Sea ice and lead classification):
For Radiance Data
If you use image_2.npy directly to create the training and testing datasets, you will obtain data in radiance form, which can be used as is for the analysis.
For Reflectance Data
To convert the chunk data into reflectance, apply the following transformation before creating the training and testing datasets:
```python
import os
import netCDF4
import numpy as np

# Folder containing the OLCI '.SEN3' products; replace 'path/to/data' with your own path.
main_folder_path = 'path/to/data'
# Folder under which the processed chunks are saved; replace 'path/to/save' as well.
save_root = 'path/to/save'

# Find every directory in main_folder_path ending in '.SEN3' (one OLCI product each).
directories = [d for d in os.listdir(main_folder_path)
               if os.path.isdir(os.path.join(main_folder_path, d)) and d.endswith('.SEN3')]

for directory in directories:
    # The product files sit in a folder of the same name nested inside the directory.
    OLCI_file_p = os.path.join(main_folder_path, directory, directory)
    print(f"Processing: {OLCI_file_p}")

    # Solar flux per band and detector, plus the detector index of every pixel.
    with netCDF4.Dataset(os.path.join(OLCI_file_p, 'instrument_data.nc')) as instrument_data:
        solar_flux = instrument_data.variables['solar_flux'][:]
        detector_index = instrument_data.variables['detector_index'][:]

    # Solar Zenith Angle (SZA) on the tie-point grid.
    with netCDF4.Dataset(os.path.join(OLCI_file_p, 'tie_geometries.nc')) as tie_geometries:
        SZA = tie_geometries.variables['SZA'][:]

    # Output directory named after the original product directory.
    save_directory = os.path.join(save_root, directory)
    os.makedirs(save_directory, exist_ok=True)

    OLCI_data = []
    for band in range(1, 22):  # OLCI has 21 bands, Oa01 to Oa21
        Rstr = f"{band:02d}"
        solar_flux_band = solar_flux[band - 1]  # solar flux of this band for every detector
        print(f"Processing Band: {Rstr}")
        print(f"Solar Flux for Band {Rstr}: {solar_flux_band}")

        # Load the band's radiance; masked (fill-value) pixels become NaN.
        with netCDF4.Dataset(os.path.join(OLCI_file_p, f'Oa{Rstr}_radiance.nc')) as OLCI_nc:
            radiance_values = np.ma.filled(OLCI_nc[f'Oa{Rstr}_radiance'][:].astype(np.float64), np.nan)

        # Expand SZA to full resolution: each tie-point column covers 64 image columns.
        angle = np.zeros_like(radiance_values)
        for x in range(angle.shape[1]):
            angle[:, x] = SZA[:, x // 64]

        # Top of Atmosphere Bidirectional Reflectance Factor (TOA BRF).
        TOA_BRF = (np.pi * radiance_values) / (solar_flux_band[detector_index] * np.cos(np.radians(angle)))
        OLCI_data.append(TOA_BRF)
        print(f"Reflectance Values Range for Band {Rstr}: {np.nanmin(TOA_BRF)}, {np.nanmax(TOA_BRF)}")

    # Stack the bands along the last axis: shape (rows, columns, 21).
    reshaped_array = np.moveaxis(np.array(OLCI_data), 0, -1)
    print("Reshaped array shape:", reshaped_array.shape)

    # Split into 5 chunks along the column axis to keep files manageable.
    split_arrays = np.array_split(reshaped_array, 5, axis=1)
    for i, arr in enumerate(split_arrays):
        print(f"Chunk {i+1} shape:", arr.shape)
        # Rstr is '21' here (the last band); every chunk contains all 21 bands.
        save_path = os.path.join(save_directory, f"chunk_{i+1}_band_{Rstr}.npy")
        np.save(save_path, arr)
        print(f"Saved Chunk {i+1} for Band {Rstr} to {save_path}")
```
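The reflectance step in the loop applies the top-of-atmosphere bidirectional reflectance factor \(\rho_\lambda = \frac{\pi L_\lambda}{F_{0,\lambda}\cos\theta_s}\), where \(L_\lambda\) is the band radiance, \(F_{0,\lambda}\) the solar flux for the pixel's detector and \(\theta_s\) the solar zenith angle. Isolating the formula in a small function (`toa_brf` is a hypothetical helper, not part of the original notebook) makes it easy to sanity-check: a radiance of \(F_0/\pi\) under an overhead sun gives a reflectance of 1, and the same radiance at \(\theta_s = 60^\circ\) gives 2, because \(\cos 60^\circ = 0.5\).

```python
import numpy as np

def toa_brf(radiance, solar_flux, sza_deg):
    """TOA reflectance factor pi * L / (F0 * cos(SZA)); works element-wise on arrays."""
    radiance = np.asarray(radiance, dtype=float)
    solar_flux = np.asarray(solar_flux, dtype=float)
    return (np.pi * radiance) / (solar_flux * np.cos(np.radians(sza_deg)))

f0 = 1500.0  # illustrative solar flux value, same units as the radiance
print(toa_brf(f0 / np.pi, f0, 0.0))   # ~1.0
print(toa_brf(f0 / np.pi, f0, 60.0))  # ~2.0
```

Because the helper is element-wise, it accepts the full-resolution radiance, per-pixel solar flux and expanded SZA arrays from the loop unchanged.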