Creating Training Data from Sentinel-2 and Sentinel-3 OLCI Data#

In our previous session, we explored how to retrieve colocated Sentinel-2 optical data alongside Sentinel-3 OLCI data. Building on that foundation, this week's focus shifts to processing these datasets into a machine-learning dataset we can use for regression. Our ultimate goal is to achieve pixel-level colocation between the two datasets and to apply regression techniques to the task of predicting sea ice concentration from S3/S2 data.

Step 0: Importing Sentinel-2 and Sentinel-3 OLCI Data#

Now, we reintroduce some previously used code with a pivotal objective: to import the raw data and transform it into a format suited to our analysis. In outline, we import the raw S2 bands and S3 OLCI radiances, project both onto a common coordinate system, crop to the overlapping region, and save the cropped arrays for the later steps. Given the extensive memory demands of this operation, we recommend following the process conceptually rather than executing it directly.

! pip install rasterio
! pip install netCDF4
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

# This takes ~60 seconds and ~2 GB of RAM

import rasterio
import numpy as np
from sklearn.mixture import GaussianMixture
import matplotlib.pyplot as plt
from pyproj import Transformer  # modern pyproj API; Proj(init=...) and transform() are deprecated

# Paths to the band images
path = '/content/drive/MyDrive/GEOL0069/2425/Week 5/' # You need to specify the path
base_path = "S2A_MSIL1C_20190301T235611_N0207_R116_T01WCU_20190302T014622.SAFE/GRANULE/L1C_T01WCU_A019275_20190301T235610/IMG_DATA/"
bands_paths = {
    'B4': path + base_path + 'T01WCU_20190301T235611_B04.jp2',
    'B3': path + base_path + 'T01WCU_20190301T235611_B03.jp2',
    'B2': path + base_path + 'T01WCU_20190301T235611_B02.jp2'
}

# Read and stack the band images
band_data = []
for band in ['B4', 'B3', 'B2']:
    with rasterio.open(bands_paths[band]) as src:
        band_data.append(src.read(1))

# Stack bands and create a mask for valid data (non-zero values in all bands)
band_stack = np.dstack(band_data)
valid_data_mask = np.all(band_stack > 0, axis=2)

# Reshape for GMM, only including valid data
X = band_stack[valid_data_mask].reshape((-1, 3))

s2_filename = bands_paths['B4']  # the same B04 image we read above

# Read the Sentinel-2 image and its geospatial information
with rasterio.open(s2_filename) as src:
    # Read the raster data and the affine transformation
    s2_data = src.read(1)
    transform_matrix = src.transform

    # Get the spatial reference system (CRS)
    srs = src.crs

# Create the grid of X,Y values (vectorised; a Python loop over ~120M pixels is very slow)
rows, cols = s2_data.shape
col_idx, row_idx = np.meshgrid(np.arange(cols), np.arange(rows))
# Apply the affine transform: x = a*col + b*row + c, y = d*col + e*row + f
x_s2 = transform_matrix.a * col_idx + transform_matrix.b * row_idx + transform_matrix.c
y_s2 = transform_matrix.d * col_idx + transform_matrix.e * row_idx + transform_matrix.f

# Build a transformer from geographic coordinates (lon/lat) to the S2 projected CRS;
# we will use it below to bring the S3 coordinates onto the S2 grid
latlon_to_s2 = Transformer.from_crs("EPSG:4326", srs, always_xy=True)


import os
import numpy as np
import matplotlib.pyplot as plt
import netCDF4

# Define the path to the main folder where your data is stored.
# Adjust this to match your own Drive layout.
main_folder_path = '/content/drive/MyDrive/GEOL0069/2425/Week 5'

# Find all directories in main_folder_path that end with '.SEN3',
# the folder format that holds the OLCI data files.
directories = [d for d in os.listdir(main_folder_path) if os.path.isdir(os.path.join(main_folder_path, d)) and d.endswith('002.SEN3')]  # load OLCI imagery

# Loop over each directory (i.e., each set of data) found above.
for directory in directories:
    # Construct the path to the OLCI data file within the directory.
    # This path is used to access the data files.
    OLCI_file_p = os.path.join(main_folder_path, directory)

    # Print the path to the current data file being processed.
    # This is helpful for tracking which file is being processed at any time.
    print(f"Processing: {OLCI_file_p}")

    # Load the instrument data from a file named 'instrument_data.nc' inside the directory.
    # This file contains various data about the instrument that captured the satellite data.
    instrument_data = netCDF4.Dataset(OLCI_file_p + '/instrument_data.nc')
    solar_flux = instrument_data.variables['solar_flux'][:]  # Extract the solar flux data.
    detector_index = instrument_data.variables['detector_index'][:]  # Extract the detector index.

    # Load tie geometries from a file named 'tie_geometries.nc'.
    # Tie geometries contain information about viewing angles, which are important for data analysis.
    tie_geometries = netCDF4.Dataset(OLCI_file_p + '/tie_geometries.nc')
    SZA = tie_geometries.variables['SZA'][:]  # Extract the Solar Zenith Angle (SZA).

    # Create a directory for saving the processed data, named after the input directory.
    # Replace 'path/to/save' with your own output location.
    save_directory = os.path.join('path/to/save', directory)
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)

    # This loop processes each radiance band in the OLCI data.
    # OLCI instruments capture multiple bands, each representing different wavelengths.
    OLCI_data = []
    for Radiance in range(1, 22):  # There are 21 bands in OLCI data.
    # for Radiance in [2,5,8,16]:  # selecting relevant bands

        Rstr = "%02d" % Radiance  # Formatting the band number.
        solar_flux_band = solar_flux[Radiance - 1]  # Get the solar flux for the current band.

        # Print information about the current band being processed.
        # This includes the band number and its corresponding solar flux.
        print(f"Processing Band: {Rstr}")
        print(f"Solar Flux for Band {Rstr}: {solar_flux_band}")

        # Load radiance values from the OLCI data file for the current band.
        OLCI_nc = netCDF4.Dataset(OLCI_file_p + '/Oa' + Rstr + '_radiance.nc')
        radiance_values = np.asarray(OLCI_nc['Oa' + Rstr + '_radiance'])

        # Expand the tie-point solar zenith angles to full resolution
        # (OLCI tie points are subsampled every 64 columns across track).
        angle = np.zeros_like(radiance_values)
        for x in range(angle.shape[1]):
            angle[:, x] = SZA[:, int(x / 64)]

        # Calculate the Top of Atmosphere Bidirectional Reflectance Factor (TOA BRF) for the current band.
        TOA_BRF = (np.pi * radiance_values) / (solar_flux_band[detector_index] * np.cos(np.radians(angle)))

        # Add the calculated TOA BRF data to the OLCI_data list.
        OLCI_data.append(TOA_BRF)

    reshaped_array = np.moveaxis(np.array(OLCI_data), 0, -1)
    OLCI_coord = netCDF4.Dataset(OLCI_file_p + '/geo_coordinates.nc')
    OLCI_lon = OLCI_coord['longitude'][:]
    OLCI_lat = OLCI_coord['latitude'][:]

# Project the OLCI lon/lat grid into the S2 projected CRS
x_s3, y_s3 = latlon_to_s2.transform(OLCI_lon, OLCI_lat)
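For reference, the conversion applied inside the band loop above is the standard top-of-atmosphere bidirectional reflectance factor,

$$
\mathrm{BRF}_{\mathrm{TOA}} = \frac{\pi L}{E_0 \cos\theta_s},
$$

where $L$ is the measured radiance, $E_0$ is the per-detector solar flux for the band, and $\theta_s$ is the solar zenith angle expanded from the tie-point grid.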

# This takes ~10 minutes to run, so you may prefer not to run it
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches


# Reshape S2 Coordinates
rows_s2, cols_s2 = s2_data.shape  # S2 resolution (e.g., 10980 x 10980)
x_s2_arr = np.asarray(x_s2).reshape(rows_s2, cols_s2)
y_s2_arr = np.asarray(y_s2).reshape(rows_s2, cols_s2)

x_s3_arr = np.asarray(x_s3)
y_s3_arr = np.asarray(y_s3)


# Compute the Extents
# S2 extents
s2_x_min, s2_x_max = x_s2_arr.min(), x_s2_arr.max()
s2_y_min, s2_y_max = y_s2_arr.min(), y_s2_arr.max()

# S3 extents
s3_x_min, s3_x_max = x_s3_arr.min(), x_s3_arr.max()
s3_y_min, s3_y_max = y_s3_arr.min(), y_s3_arr.max()

print("S2 Extents:")
print("   x: ", s2_x_min, "to", s2_x_max)
print("   y: ", s2_y_min, "to", s2_y_max)

print("\nS3 Extents:")
print("   x: ", s3_x_min, "to", s3_x_max)
print("   y: ", s3_y_min, "to", s3_y_max)


# Compute the Intersection (Overlapping Area)

overlap_x_min = max(s2_x_min, s3_x_min)
overlap_x_max = min(s2_x_max, s3_x_max)
overlap_y_min = max(s2_y_min, s3_y_min)
overlap_y_max = min(s2_y_max, s3_y_max)

print("\nOverlap Region Extents:")
print("   x: ", overlap_x_min, "to", overlap_x_max)
print("   y: ", overlap_y_min, "to", overlap_y_max)


# Create Conditions for Overlapping Pixels

condition_s2 = (
    (x_s2_arr >= overlap_x_min) & (x_s2_arr <= overlap_x_max) &
    (y_s2_arr >= overlap_y_min) & (y_s2_arr <= overlap_y_max)
)

condition_s3 = (
    (x_s3_arr >= overlap_x_min) & (x_s3_arr <= overlap_x_max) &
    (y_s3_arr >= overlap_y_min) & (y_s3_arr <= overlap_y_max)
)


# Plot the Overlap Regions
# Plot for S3 (OLCI)
plt.figure(figsize=(8, 6))
scatter1 = plt.scatter(
    x_s3_arr[condition_s3],
    y_s3_arr[condition_s3],
    c=reshaped_array[condition_s3, 0],
    cmap='viridis',
    vmin=0.7, vmax=1,
    s=10
)
plt.colorbar(scatter1, label='Reflectance (Band 0)')
plt.title('S3 Overlap Region with S2')
plt.xlabel('X Coordinate')
plt.ylabel('Y Coordinate')

# Add the red dotted rectangle for the zoomed region
ax = plt.gca()  # get the current axes

zoom_x_min = 350000.0
zoom_x_max = 390000.0
zoom_y_min = 7790230.0
zoom_y_max = 7820000.0

# Rectangle width and height
rect_width = zoom_x_max - zoom_x_min
rect_height = zoom_y_max - zoom_y_min

# Create a Rectangle patch in data coordinates
rect = patches.Rectangle(
    (zoom_x_min, zoom_y_min),
    rect_width,
    rect_height,
    linewidth=2,
    edgecolor='r',
    facecolor='none',
    linestyle='--'
)

# Add the patch to the Axes
ax.add_patch(rect)
plt.savefig('S3_overlap_region_with_zoom.png', dpi=300, bbox_inches='tight')
plt.show()



plt.figure(figsize=(8, 6))
scatter2 = plt.scatter(
    x_s2_arr[condition_s2],
    y_s2_arr[condition_s2],
    c=band_stack[condition_s2, 2] / 10000,
    cmap='viridis',
    vmin=0.7, vmax=1,
    s=1
)
plt.colorbar(scatter2, label='Reflectance (B2/10000)')
plt.title('S2 Overlap Region')
plt.xlabel('X Coordinate')
plt.ylabel('Y Coordinate')

# Add the same rectangle for the S2 plot
ax2 = plt.gca()
rect2 = patches.Rectangle(
    (zoom_x_min, zoom_y_min),
    rect_width,
    rect_height,
    linewidth=2,
    edgecolor='r',
    facecolor='none',
    linestyle='--'
)
ax2.add_patch(rect2)
plt.savefig('S2_overlap_region_with_zoom.png', dpi=300, bbox_inches='tight')
plt.show()

The output of the above cell is shown below: two plots of the co-located region, one on the S2 grid and one on the S3 grid.

# No need to run this
from IPython.display import Image, display
display(Image(filename='/content/S2_overlap_region_with_zoom.png', width=600))
display(Image(filename='/content/S3_overlap_region_with_zoom.png', width=600))
[Figures: S2 and S3 overlap regions, each with the red dashed zoom rectangle]

Now we select a subregion on which the rest of the analysis will be based. The cell below extracts the subregion in both S3 and S2 format and saves it to NumPy arrays for further analysis.

import numpy as np
import matplotlib.pyplot as plt

save_path = '/content/drive/MyDrive/GEOL0069/2425/Week 5/Regression_application'
rows_s2, cols_s2 = s2_data.shape
x_s2_arr = np.asarray(x_s2).reshape(rows_s2, cols_s2)
y_s2_arr = np.asarray(y_s2).reshape(rows_s2, cols_s2)

x_s3_arr = np.asarray(x_s3)
y_s3_arr = np.asarray(y_s3)

s2_x_min, s2_x_max = x_s2_arr.min(), x_s2_arr.max()
s2_y_min, s2_y_max = y_s2_arr.min(), y_s2_arr.max()

# S3 extents
s3_x_min, s3_x_max = x_s3_arr.min(), x_s3_arr.max()
s3_y_min, s3_y_max = y_s3_arr.min(), y_s3_arr.max()

print("S2 Extents:")
print("   x:", s2_x_min, "to", s2_x_max)
print("   y:", s2_y_min, "to", s2_y_max)

print("\nS3 Extents:")
print("   x:", s3_x_min, "to", s3_x_max)
print("   y:", s3_y_min, "to", s3_y_max)

# Compute the Intersection (Overlapping Area) from the original extents
overlap_x_min = max(s2_x_min, s3_x_min)
overlap_x_max = min(s2_x_max, s3_x_max)
overlap_y_min = max(s2_y_min, s3_y_min)
overlap_y_max = min(s2_y_max, s3_y_max)

print("\nOverlap Region Extents:")
print("   x:", overlap_x_min, "to", overlap_x_max)
print("   y:", overlap_y_min, "to", overlap_y_max)

condition_s2 = (
    (x_s2_arr >= overlap_x_min) & (x_s2_arr <= overlap_x_max) &
    (y_s2_arr >= overlap_y_min) & (y_s2_arr <= overlap_y_max)
)

condition_s3 = (
    (x_s3_arr >= overlap_x_min) & (x_s3_arr <= overlap_x_max) &
    (y_s3_arr >= overlap_y_min) & (y_s3_arr <= overlap_y_max)
)


zoom_x_min = 350000.0
zoom_x_max = 390000.0
zoom_y_min = 7790230.0
zoom_y_max = 7820000.0

condition_zoom_s2 = condition_s2 & (
    (x_s2_arr >= zoom_x_min) & (x_s2_arr <= zoom_x_max) &
    (y_s2_arr >= zoom_y_min) & (y_s2_arr <= zoom_y_max)
)

condition_zoom_s3 = condition_s3 & (
    (x_s3_arr >= zoom_x_min) & (x_s3_arr <= zoom_x_max) &
    (y_s3_arr >= zoom_y_min) & (y_s3_arr <= zoom_y_max)
)


s2_zoom_x = x_s2_arr[condition_zoom_s2]
s2_zoom_y = y_s2_arr[condition_zoom_s2]
s2_zoom_data = band_stack[condition_zoom_s2, :]

# For S3
s3_zoom_x = x_s3_arr[condition_zoom_s3]
s3_zoom_y = y_s3_arr[condition_zoom_s3]
s3_zoom_data = reshaped_array[condition_zoom_s3, :]

np.savez(save_path+'/s2_zoomed_data.npz', x=s2_zoom_x, y=s2_zoom_y, band_data=s2_zoom_data)
np.savez(save_path+'/s3_zoomed_data.npz', x=s3_zoom_x, y=s3_zoom_y, reflectance=s3_zoom_data)


Step 1: Loading the required data#

We load the saved S2 and S3 data here.

import numpy as np
import matplotlib.pyplot as plt

save_path = '/content/drive/MyDrive/GEOL0069/2425/Week 5/Regression_application'
s2_data = np.load(save_path+'/s2_zoomed_data.npz')
s2_x = s2_data['x']
s2_y = s2_data['y']
s2_band_data = s2_data['band_data']

s3_data = np.load(save_path+'/s3_zoomed_data.npz')
s3_x = s3_data['x']
s3_y = s3_data['y']
s3_reflectance = s3_data['reflectance']

plt.figure(figsize=(8, 6))
scatter_s2 = plt.scatter(
    s2_x, s2_y,
    c=s2_band_data[:, 2] / 10000,
    cmap='viridis',
    s=1
)
plt.colorbar(scatter_s2, label='Reflectance (B2/10000)')
plt.title('Zoomed S2 Overlap Region')
plt.xlabel('X Coordinate')
plt.ylabel('Y Coordinate')
plt.show()


plt.figure(figsize=(8, 6))
scatter_s3 = plt.scatter(
    s3_x, s3_y,
    c=s3_reflectance[:, 0],
    cmap='viridis',
    s=10
)
plt.colorbar(scatter_s3, label='Reflectance (Band 0)')
plt.title('Zoomed S3 Overlap Region')
plt.xlabel('X Coordinate')
plt.ylabel('Y Coordinate')
plt.show()
[Figures: zoomed S2 and S3 overlap regions]

Step 2: Label S2 pixels (see Week 4 on unsupervised classification)#

In the code below, we use K-Means clustering to label the pixels of the Sentinel-2 image; these labels will form part of the training data.
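As a reminder from Week 4, K-Means with $k = 2$ assigns each pixel to one of two clusters by minimising the within-cluster sum of squared distances,

$$
\min_{C_1, C_2} \sum_{k=1}^{2} \sum_{\mathbf{x}_i \in C_k} \lVert \mathbf{x}_i - \boldsymbol{\mu}_k \rVert^2,
$$

where each $\mathbf{x}_i$ is a 3-band (B4, B3, B2) pixel vector and $\boldsymbol{\mu}_k$ is the mean of the pixels assigned to cluster $C_k$.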

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from matplotlib.colors import ListedColormap

kmeans = KMeans(n_clusters=2, random_state=0)
labels = kmeans.fit_predict(s2_band_data)

# Make the label convention explicit: 1 = ice (the brighter cluster), 0 = water.
# K-Means assigns cluster IDs arbitrarily, so flip the labels if necessary.
if kmeans.cluster_centers_[0].mean() > kmeans.cluster_centers_[1].mean():
    labels = 1 - labels

unique_x = np.sort(np.unique(s2_x))
unique_y = np.sort(np.unique(s2_y))

num_cols = unique_x.size
num_rows = unique_y.size

# The zoomed S2 subset is a regular grid, so the labels reshape cleanly
assert labels.size == num_rows * num_cols
binary_mask = labels.reshape(num_rows, num_cols)

plt.figure(figsize=(8, 5))
binary_cmap = ListedColormap(["black", "white"])
plt.imshow(binary_mask, cmap=binary_cmap, origin='upper',
           extent=[unique_x.min(), unique_x.max(), unique_y.min(), unique_y.max()])
plt.colorbar(ticks=[0, 1], label="Class", pad=0.02)
plt.clim(-0.5, 1.5)
plt.title('Binary K-means Segmentation of Zoomed-in S2 Data')
plt.xlabel('X Coordinate')
plt.ylabel('Y Coordinate')
plt.show()
[Figure: binary K-Means segmentation of the zoomed S2 data]

Step 3: Find co-located pixels using KDTree#

The code below constructs a KDTree from the S3 coordinates so that, for each S2 data point, the nearest S3 pixel can be found quickly from their (x, y) positions. It then groups the S2 points by their nearest S3 neighbour and computes an aggregated value, the sea ice concentration, for each S3 pixel. The result is a dataset we can use for ML tasks.
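Concretely, if $G_j$ is the set of S2 pixels whose nearest neighbour is S3 pixel $j$, the aggregated target is the ice fraction within that group,

$$
\mathrm{SIC}_j = \frac{1}{\lvert G_j \rvert} \sum_{i \in G_j} \mathbb{1}[\text{pixel } i \text{ is ice}].
$$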

The logic of our KDTree is, for each Sentinel-2 pixel, to find its nearest Sentinel-3 pixel, ensuring that every fine-resolution S2 measurement is uniquely assigned to a single coarse-resolution S3 pixel. This avoids the pitfalls of using a fixed radius around each S3 pixel, which could miss some S2 pixels or count others more than once. An illustration is shown here, followed by a toy sketch of the grouping logic.

# No need to run this
from IPython.display import Image, display
display(Image(filename='/content/S3S2KDtree.png', width=600))
[Figure: KD-tree nearest-neighbour assignment of S2 pixels to S3 pixels]
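Before the full-scale query below, here is a minimal, self-contained sketch of the same grouping logic on made-up coordinates. All values here are hypothetical and purely illustrative:

import numpy as np
from scipy.spatial import KDTree
from collections import defaultdict

# Hypothetical coarse (S3-like) and fine (S2-like) pixel centres, in metres
coarse = np.array([[0.0, 0.0], [300.0, 0.0]])
fine = np.array([[10.0, 5.0], [120.0, -40.0], [290.0, 20.0], [310.0, -10.0]])
fine_labels = np.array([1, 1, 0, 1])  # toy ice (1) / water (0) labels

tree = KDTree(coarse)          # index the coarse pixel centres
_, nearest = tree.query(fine)  # nearest coarse pixel for each fine pixel

groups = defaultdict(list)
for i, j in enumerate(nearest):
    groups[j].append(i)        # each fine pixel is assigned exactly once

for j, members in sorted(groups.items()):
    print(f"coarse pixel {j}: SIC = {fine_labels[members].mean():.2f}")
# coarse pixel 0: SIC = 1.00
# coarse pixel 1: SIC = 0.50

Because every fine pixel is assigned to exactly one coarse pixel, the group means partition the data with no gaps and no double counting.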
from scipy.spatial import KDTree
from collections import defaultdict


s2_points = np.vstack((s2_x, s2_y)).T
s3_points = np.vstack((s3_x, s3_y)).T

tree = KDTree(s3_points)

distances, s3_indices_for_s2 = tree.query(s2_points)

grouped = defaultdict(list)
for s2_idx, s3_idx in enumerate(s3_indices_for_s2):
    grouped[s3_idx].append(s2_idx)

SICavg = np.full(len(s3_points), np.nan)

for s3_idx, s2_indices in grouped.items():
    # labels: 1 = ice, 0 = water, so the group mean is the ice fraction
    SICavg[s3_idx] = np.mean(labels[s2_indices])
# Plot the distances as a histogram
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(distances, bins = 100)
ax.set_xlabel('Distance (m)')
ax.set_ylabel('Frequency')
ax.set_title('S2 to nearest S3 distance histogram')
ax.text(0.8, 0.9, f'Mean distance: {np.mean(distances):.2f} m', transform=ax.transAxes, ha='center')
plt.show()
[Figure: histogram of S2-to-nearest-S3 distances]
plt.figure(figsize=(8, 6))
scatter = plt.scatter(s3_points[:, 0], s3_points[:, 1],
                      c=SICavg, cmap='viridis', s=20)
plt.colorbar(scatter, label='Sea Ice Concentration (SIC)')
plt.title("S3 Points with Sea Ice Concentration Derived from Co-located S2")
plt.xlabel("X Coordinate")
plt.ylabel("Y Coordinate")
plt.show()
[Figure: S3 points coloured by S2-derived sea ice concentration]
import numpy as np
import pandas as pd


valid_mask = ~np.isnan(SICavg)
X_valid = s3_reflectance[valid_mask, :]
y_valid = SICavg[valid_mask]

n_bands = s3_reflectance.shape[1]
column_names = [f'refl_band_{i+1}' for i in range(n_bands)]
df = pd.DataFrame(X_valid, columns=column_names)
df['SIC'] = y_valid
save_path = '/content/drive/MyDrive/GEOL0069/2425/Week 5/Regression_application'
np.savez(save_path+'/s3_ML_dataset.npz', X=X_valid, y=y_valid)

df.to_csv(save_path+'/s3_ML_dataset.csv', index=False)

print("ML dataset created and saved!")
print("Feature shape:", X_valid.shape)
print("Target shape:", y_valid.shape)
ML dataset created and saved!
Feature shape: (14445, 21)
Target shape: (14445,)
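With the features and targets saved, the dataset is ready for the regression step. As a preview, a minimal sketch of fitting a regressor to the saved arrays might look like the following; the 80/20 split and the plain linear model are illustrative assumptions, not the prescribed setup:

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_absolute_error

save_path = '/content/drive/MyDrive/GEOL0069/2425/Week 5/Regression_application'
data = np.load(save_path + '/s3_ML_dataset.npz')  # the file saved above
X, y = data['X'], data['y']                       # 21 OLCI bands -> SIC

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
model = LinearRegression().fit(X_train, y_train)
print("Test MAE:", mean_absolute_error(y_test, model.predict(X_test)))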