{ "cells": [ { "cell_type": "markdown", "id": "48bcf8ef", "metadata": { "lines_to_next_cell": 0 }, "source": [ "### Inline Example of Local Expert 'Optimal Interpolation' on Satellite Data\n", "\n", "\n", "# Using Colab? Then clone and install" ] }, { "cell_type": "code", "execution_count": null, "id": "6a019228", "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [ "try:\n", " import google.colab\n", " IN_COLAB = True\n", "except:\n", " IN_COLAB = False\n", "\n", "if IN_COLAB:\n", " import subprocess\n", " import os\n", " import re\n", "\n", " # change to working directory\n", " work_dir = \"/content\"\n", "\n", " assert os.path.exists(work_dir), f\"workspace directory: {work_dir} does not exist\"\n", " os.chdir(work_dir)\n", "\n", " # clone repository\n", " command = \"git clone https://github.com/CPOMUCL/GPSat.git\"\n", " result = subprocess.run(command.split(), capture_output=True, text=True)\n", " print(result.stdout)\n", "\n", " repo_dir = os.path.join(work_dir, \"GPSat\")\n", "\n", " print(f\"changing directory to: {repo_dir}\")\n", " os.chdir(repo_dir)\n", "\n", " # exclude certain requirements if running on colab - namely avoid installing/upgrading tensorflow\n", " new_req = []\n", " with open(os.path.join(repo_dir, \"requirements.txt\"), \"r\") as f:\n", " for line in f.readlines():\n", " # NOTE: here also removing numpy requirement\n", " if re.search(\"^tensorflow|^numpy\", line):\n", " new_req.append(\"#\" + line)\n", " else:\n", " new_req.append(line)\n", "\n", " # create a colab specific requirements file\n", " with open(os.path.join(repo_dir, \"requirements_colab.txt\"), \"w\") as f:\n", " f.writelines(new_req)\n", "\n", " # install the requirements\n", " command = \"pip install -r requirements_colab.txt\"\n", " with subprocess.Popen(command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) as proc:\n", " # Stream the standard output in real-time\n", " for line in proc.stdout:\n", " print(line, end='')\n", "\n", " # install the GPSat pacakge in editable mode\n", " command = \"pip install -e .\"\n", " with subprocess.Popen(command.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) as proc:\n", " # Stream the standard output in real-time\n", " for line in proc.stdout:\n", " print(line, end='')" ] }, { "cell_type": "markdown", "id": "e69a69e2", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# Import Packages" ] }, { "cell_type": "code", "execution_count": null, "id": "a33da71e", "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [ "\n", "import os\n", "import re\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "from global_land_mask import globe\n", "from GPSat import get_data_path, get_parent_path\n", "from GPSat.dataprepper import DataPrep\n", "from GPSat.dataloader import DataLoader\n", "from GPSat.utils import stats_on_vals, WGS84toEASE2, EASE2toWGS84, cprint, grid_2d_flatten, get_weighted_values\n", "from GPSat.local_experts import LocalExpertOI, get_results_from_h5file\n", "from GPSat.plot_utils import plot_wrapper, plot_pcolormesh, get_projection, plot_pcolormesh_from_results_data, plot_hyper_parameters\n", "from GPSat.postprocessing import smooth_hyperparameters" ] }, { "cell_type": "markdown", "id": "9960913a", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# Parameters" ] }, { "cell_type": "code", "execution_count": null, "id": "b4344a55", "metadata": {}, "outputs": [], "source": [ "\n", "# NOTE: there are parameters values that are set inline in the cells below\n", "\n", "# lat,lon center (origin) used for converting between WGS84 and EASE2 projections\n", "lat_0 = 90\n", "lon_0 = 0\n", "\n", "# expert location parameters\n", "# spacing between experts (laid out on a grid), in meters\n", "expert_spacing = 200_000\n", "# range of experts, from origin, in meters\n", "# expert_x_range = [-750_000.0, 1000_000.0]\n", "# expert_y_range = [-500_000.0, 1250_000.0]\n", "expert_x_range = [-500_000, 500_000]\n", "expert_y_range = [-500_000, 500_000]\n", "\n", "# prediction spacing\n", "# (below predictions same range as experts)\n", "pred_spacing = 5_000\n", "\n", "\n", "# model parameters\n", "# Set training and inference radius\n", "# - distance observations need to be away from expert locations to be included in training\n", "training_radius = 300_000 # 300km\n", "# - distance prediction locations need to be away from expert locations in order of predictions to be made\n", "inference_radius = 200_000 # 200km\n", "\n", "\n", "# plotting\n", "# extent = [lon min, lat max, lat min, lat max]\n", "extent = [-180, 180, 60, 90]\n", "\n", "# which projection to use: \"north\" or \"south\"\n", "projection = \"north\"" ] }, { "cell_type": "markdown", "id": "d6f45e2e", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# read in raw data\n", "\n", "add each key in col_func as a column, using a specified function + arguments\n", "values are unpacked and passed to GPSat.utils.config_func" ] }, { "cell_type": "code", "execution_count": null, "id": "9f1d5337", "metadata": {}, "outputs": [], "source": [ "\n", "\n", "df = DataLoader.read_flat_files(file_dirs=get_data_path(\"example\"),\n", " file_regex=\"_RAW\\.csv$\",\n", " col_funcs={\n", " \"source\": {\n", " \"func\": lambda x: re.sub('_RAW.*$', '', os.path.basename(x)),\n", " \"filename_as_arg\": True\n", " }\n", " })\n", "\n", "# convert lon, lat, datetime to x, y, t - to be used as the coordinate space\n", "df['x'], df['y'] = WGS84toEASE2(lon=df['lon'], lat=df['lat'], lat_0=lat_0, lon_0=lon_0)\n", "df['t'] = df['datetime'].values.astype(\"datetime64[D]\").astype(float)" ] }, { "cell_type": "markdown", "id": "3cb72996", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# stats on data" ] }, { "cell_type": "code", "execution_count": null, "id": "602f3987", "metadata": {}, "outputs": [], "source": [ "\n", "print(\"*\" * 20)\n", "print(\"summary / stats table on metric (use for trimming)\")\n", "\n", "val_col = 'z'\n", "vals = df[val_col].values\n", "stats_df = stats_on_vals(vals=vals, name=val_col,\n", " qs=[0.01, 0.05] + np.arange(0.1, 1.0, 0.1).tolist() + [0.95, 0.99])\n", "\n", "print(stats_df)" ] }, { "cell_type": "markdown", "id": "a3d1e473", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# visualise data" ] }, { "cell_type": "code", "execution_count": null, "id": "fa6368de", "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [ "\n", "# plot observations and histogram\n", "fig, stats_df = plot_wrapper(plt_df=df,\n", " val_col=val_col,\n", " max_obs=500_000,\n", " vmin_max=[-0.1, 0.5],\n", " projection=projection,\n", " extent=extent)\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "6fee9821", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# bin raw data\n", "bin by date, source - returns a DataSet" ] }, { "cell_type": "code", "execution_count": null, "id": "f3230fbc", "metadata": {}, "outputs": [], "source": [ "\n", "bin_ds = DataPrep.bin_data_by(df=df.loc[(df['z'] > -0.35) & (df['z'] < 0.65)],\n", " by_cols=['t', 'source'],\n", " val_col=val_col,\n", " x_col='x',\n", " y_col='y',\n", " grid_res=50_000,\n", " x_range=[-4500_000.0, 4500_000.0],\n", " y_range=[-4500_000.0, 4500_000.0])\n", "\n", "# convert bin data to DataFrame\n", "# - removing all the nans that would be added at grid locations away from data\n", "bin_df = bin_ds.to_dataframe().dropna().reset_index()" ] }, { "cell_type": "markdown", "id": "e91892fb", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# plot binned data" ] }, { "cell_type": "code", "execution_count": null, "id": "574556f3", "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [ "\n", "# this will plot all observations, some on top of each other\n", "bin_df['lon'], bin_df['lat'] = EASE2toWGS84(bin_df['x'], bin_df['y'],\n", " lat_0=lat_0, lon_0=lon_0)\n", "\n", "# plot observations and histogram\n", "fig, stats_df = plot_wrapper(plt_df=bin_df,\n", " val_col=val_col,\n", " max_obs=500_000,\n", " vmin_max=[-0.1, 0.5],\n", " projection=projection,\n", " extent=extent)\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "034a84cb", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# expert locations\n", "on evenly spaced grid" ] }, { "cell_type": "code", "execution_count": null, "id": "ca09f51d", "metadata": {}, "outputs": [], "source": [ "\n", "xy_grid = grid_2d_flatten(x_range=expert_x_range,\n", " y_range=expert_y_range,\n", " step_size=expert_spacing)\n", "\n", "# store in dataframe\n", "eloc = pd.DataFrame(xy_grid, columns=['x', 'y'])\n", "\n", "# add a time coordinate\n", "eloc['t'] = np.floor(df['t'].mean())" ] }, { "cell_type": "markdown", "id": "fe1ba441", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# plot expert locations" ] }, { "cell_type": "code", "execution_count": null, "id": "11d4c738", "metadata": {}, "outputs": [], "source": [ "\n", "eloc['lon'], eloc['lat'] = EASE2toWGS84(eloc['x'], eloc['y'],\n", " lat_0=lat_0, lon_0=lon_0)\n", "\n", "fig = plt.figure(figsize=(12, 12))\n", "ax = fig.add_subplot(1, 1, 1, projection=get_projection(projection))\n", "\n", "plot_pcolormesh(ax=ax,\n", " lon=eloc['lon'],\n", " lat=eloc['lat'],\n", " plot_data=eloc['t'],\n", " title=\"expert locations\",\n", " scatter=True,\n", " s=20,\n", " fig=fig,\n", " extent=extent)\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "75d270d9", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# prediction locations" ] }, { "cell_type": "code", "execution_count": null, "id": "2397e50f", "metadata": {}, "outputs": [], "source": [ "\n", "pred_xy_grid = grid_2d_flatten(x_range=expert_x_range,\n", " y_range=expert_y_range,\n", " step_size=pred_spacing)\n", "\n", "# store in dataframe\n", "# NOTE: the missing 't' coordinate will be determine by the expert location\n", "# - alternatively the prediction location can be specified\n", "ploc = pd.DataFrame(pred_xy_grid, columns=['x', 'y'])\n", "\n", "ploc['lon'], ploc['lat'] = EASE2toWGS84(ploc['x'], ploc['y'],\n", " lat_0=lat_0, lon_0=lon_0)\n", "\n", "# identify if a position is in the ocean (water) or not\n", "ploc[\"is_in_ocean\"] = globe.is_ocean(ploc['lat'], ploc['lon'])\n", "\n", "# keep only prediction locations in ocean\n", "ploc = ploc.loc[ploc[\"is_in_ocean\"]]" ] }, { "cell_type": "markdown", "id": "a8e3e03a", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# plot prediction locations" ] }, { "cell_type": "code", "execution_count": null, "id": "b8543e7e", "metadata": {}, "outputs": [], "source": [ "\n", "fig = plt.figure(figsize=(12, 12))\n", "ax = fig.add_subplot(1, 1, 1, projection=get_projection(projection))\n", "\n", "plot_pcolormesh(ax=ax,\n", " lon=ploc['lon'],\n", " lat=ploc['lat'],\n", " plot_data=np.full(len(ploc), 1.0), # np.arange(len(ploc)),\n", " title=\"prediction locations\",\n", " scatter=True,\n", " s=0.1,\n", " # fig=fig,\n", " extent=extent)\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "1aeac457", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# configurations:" ] }, { "cell_type": "code", "execution_count": null, "id": "e8919ed6", "metadata": {}, "outputs": [], "source": [ "\n", "# observation data\n", "data = {\n", " \"data_source\": bin_df,\n", " \"obs_col\": \"z\",\n", " \"coords_col\": [\"x\", \"y\", \"t\"],\n", " # selection criteria used for each local expert\n", " \"local_select\": [\n", " {\n", " \"col\": \"t\",\n", " \"comp\": \"<=\",\n", " \"val\": 4\n", " },\n", " {\n", " \"col\": \"t\",\n", " \"comp\": \">=\",\n", " \"val\": -4\n", " },\n", " {\n", " \"col\": [\n", " \"x\",\n", " \"y\"\n", " ],\n", " \"comp\": \"<\",\n", " \"val\": training_radius\n", " }\n", " ]\n", "}\n", "\n", "# local expert locations\n", "local_expert = {\n", " \"source\": eloc\n", "}\n", "\n", "# model\n", "model = {\n", " \"oi_model\": \"GPflowGPRModel\",\n", " \"init_params\": {\n", " # scale (divide) coordinates\n", " \"coords_scale\": [50000, 50000, 1],\n", " # can specify initial parameters values for model:\n", " # \"noise_variance\": 0.10,\n", " # \"kernel_kwargs\": {\n", " # \"lengthscales\": [2.0, 2.0, 1.0],\n", " # \"variance\": 0.05\n", " # }\n", " },\n", " # keyword arguments to be passed into each model/local expert's optimise_parameters method\n", " \"optim_kwargs\": {\n", " # parameters to be fixed (not trainable)\n", " # \"fixed_params\": [\"likelihood_variance\"]\n", " },\n", " \"constraints\": {\n", " # lengthscales - same order coord_col (see data)\n", " # - given in unscaled units\n", " \"lengthscales\": {\n", " \"low\": [1e-08, 1e-08, 1e-08],\n", " \"high\": [600000, 600000, 9]\n", " },\n", " \"likelihood_variance\": {\n", " \"low\": 0.00125,\n", " \"high\": 0.01\n", " }\n", " }\n", "}\n", "\n", "# prediction locations\n", "pred_loc = {\n", " \"method\": \"from_dataframe\",\n", " \"df\": ploc,\n", " \"max_dist\": inference_radius\n", "}" ] }, { "cell_type": "markdown", "id": "d8585ee7", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# Local Expert OI\n", "\n", "if process falls over here when calling run(), try: Runtime -> \"Restart and run all\"" ] }, { "cell_type": "code", "execution_count": null, "id": "e9c1d83d", "metadata": {}, "outputs": [], "source": [ "\n", "locexp = LocalExpertOI(expert_loc_config=local_expert,\n", " data_config=data,\n", " model_config=model,\n", " pred_loc_config=pred_loc)\n", "\n", "# run optimal interpolation\n", "# - no predictions locations supplied\n", "store_path = get_parent_path(\"results\", \"inline_example.h5\")\n", "\n", "# for the purposes of a simple example, if store_path exists: delete it\n", "if os.path.exists(store_path):\n", " cprint(f\"removing: {store_path}\", \"FAIL\")\n", " os.remove(store_path)\n", "\n", "# run optimal interpolation\n", "locexp.run(store_path=store_path,\n", " optimise=True,\n", " check_config_compatible=False)" ] }, { "cell_type": "markdown", "id": "47e438d4", "metadata": { "lines_to_next_cell": 0 }, "source": [ "results are store in hdf5" ] }, { "cell_type": "code", "execution_count": null, "id": "08fbd08f", "metadata": {}, "outputs": [], "source": [ "\n", "# extract, store in dict\n", "dfs, oi_config = get_results_from_h5file(store_path)\n", "\n", "print(f\"tables in results file: {list(dfs.keys())}\")" ] }, { "cell_type": "markdown", "id": "0bd35758", "metadata": { "lines_to_next_cell": 0 }, "source": [ "Plot Hyper Parameters" ] }, { "cell_type": "code", "execution_count": null, "id": "1552aebc", "metadata": {}, "outputs": [], "source": [ "\n", "# a template to be used for each created plot config\n", "plot_template = {\n", " \"plot_type\": \"heatmap\",\n", " \"x_col\": \"x\",\n", " \"y_col\": \"y\",\n", " # use a northern hemisphere projection, centered at (lat,lon) = (90,0)\n", " \"subplot_kwargs\": {\"projection\": projection},\n", " \"lat_0\": lat_0,\n", " \"lon_0\": lon_0,\n", " # any additional arguments for plot_hist\n", " \"plot_kwargs\": {\n", " \"scatter\": True,\n", " },\n", " # lat/lon_col needed if scatter = True\n", " # TODO: remove the need for this\n", " \"lat_col\": \"lat\",\n", " \"lon_col\": \"lon\",\n", "}\n", "\n", "fig = plot_hyper_parameters(dfs,\n", " coords_col=oi_config[0]['data']['coords_col'], # ['x', 'y', 't']\n", " row_select=None, # this could be used to select a specific date in results data\n", " table_names=[\"lengthscales\", \"kernel_variance\", \"likelihood_variance\"],\n", " plot_template=plot_template,\n", " plots_per_row=3,\n", " suptitle=\"hyper params\",\n", " qvmin=0.01,\n", " qvmax=0.99)\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "0546e078", "metadata": { "lines_to_next_cell": 0 }, "source": [ "Smooth Hyper Parameters" ] }, { "cell_type": "code", "execution_count": null, "id": "ff7a7543", "metadata": {}, "outputs": [], "source": [ "\n", "smooth_config = {\n", " # get hyper parameters from the previously stored results\n", " \"result_file\": store_path,\n", " # store the smoothed hyper parameters in the same file\n", " \"output_file\": store_path,\n", " # get the hyper params from tables ending with this suffix (\"\" is default):\n", " \"reference_table_suffix\": \"\",\n", " # newly smoothed hyper parameters will be store in tables ending with table_suffix\n", " \"table_suffix\": \"_SMOOTHED\",\n", " # dimension names to smooth over\n", " \"xy_dims\": [\n", " \"x\",\n", " \"y\"\n", " ],\n", " # parameters to smooth\n", " \"params_to_smooth\": [\n", " \"lengthscales\",\n", " \"kernel_variance\",\n", " \"likelihood_variance\"\n", " ],\n", " # length scales for the kernel smoother in each dimension\n", " # - as well as any min/max values to apply\n", " \"smooth_config_dict\": {\n", " \"lengthscales\": {\n", " \"l_x\": 200_000,\n", " \"l_y\": 200_000\n", " },\n", " \"likelihood_variance\": {\n", " \"l_x\": 200_000,\n", " \"l_y\": 200_000,\n", " \"max\": 0.3\n", " },\n", " \"kernel_variance\": {\n", " \"l_x\": 200_000,\n", " \"l_y\": 200_000,\n", " \"max\": 0.1\n", " }\n", " },\n", " \"save_config_file\": True\n", "}\n", "\n", "smooth_result_config_file = smooth_hyperparameters(**smooth_config)\n", "\n", "# modify the model configuration to include \"load_params\"\n", "model_load_params = model.copy()\n", "model_load_params[\"load_params\"] = {\n", " \"file\": store_path,\n", " \"table_suffix\": smooth_config[\"table_suffix\"]\n", "}\n", "\n", "locexp_smooth = LocalExpertOI(expert_loc_config=local_expert,\n", " data_config=data,\n", " model_config=model_load_params,\n", " pred_loc_config=pred_loc)\n", "\n", "# run optimal interpolation (again)\n", "# - this time don't optimise hyper parameters, but make predictions\n", "# - store results in new tables ending with '_SMOOTHED'\n", "locexp_smooth.run(store_path=store_path,\n", " optimise=False,\n", " predict=True,\n", " table_suffix=smooth_config['table_suffix'],\n", " check_config_compatible=False)" ] }, { "cell_type": "markdown", "id": "80a889c7", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# Plot Smoothed Hyper Parameters" ] }, { "cell_type": "code", "execution_count": null, "id": "9f815418", "metadata": {}, "outputs": [], "source": [ "# extract, store in dict\n", "dfs, oi_config = get_results_from_h5file(store_path)\n", "\n", "fig = plot_hyper_parameters(dfs,\n", " coords_col=oi_config[0]['data']['coords_col'], # ['x', 'y', 't']\n", " row_select=None,\n", " table_names=[\"lengthscales\", \"kernel_variance\", \"likelihood_variance\"],\n", " table_suffix=smooth_config[\"table_suffix\"],\n", " plot_template=plot_template,\n", " plots_per_row=3,\n", " suptitle=\"smoothed hyper params\",\n", " qvmin=0.01,\n", " qvmax=0.99)\n", "\n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "7192a928", "metadata": { "lines_to_next_cell": 0 }, "source": [ "# get weighted combinations predictions and plot" ] }, { "cell_type": "code", "execution_count": null, "id": "3f82ab7b", "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [ "\n", "plt_data = dfs[\"preds\" + smooth_config[\"table_suffix\"]]\n", "\n", "weighted_values_kwargs = {\n", " \"ref_col\": [\"pred_loc_x\", \"pred_loc_y\", \"pred_loc_t\"],\n", " \"dist_to_col\": [\"x\", \"y\", \"t\"],\n", " \"val_cols\": [\"f*\", \"f*_var\"],\n", " \"weight_function\": \"gaussian\",\n", " \"lengthscale\": inference_radius/2\n", "}\n", "plt_data = get_weighted_values(df=plt_data, **weighted_values_kwargs)\n", "\n", "plt_data['lon'], plt_data['lat'] = EASE2toWGS84(plt_data['pred_loc_x'], plt_data['pred_loc_y'])\n", "\n", "fig = plt.figure(figsize=(12, 12))\n", "ax = fig.add_subplot(1, 1, 1, projection=get_projection(projection))\n", "plot_pcolormesh_from_results_data(ax=ax,\n", " dfs={\"preds\": plt_data},\n", " table='preds',\n", " val_col=\"f*\",\n", " x_col='pred_loc_x',\n", " y_col='pred_loc_y',\n", " fig=fig)\n", "plt.tight_layout()\n", "plt.show()" ] } ], "metadata": { "jupytext": { "cell_metadata_filter": "-all", "main_language": "python", "notebook_metadata_filter": "-all" } }, "nbformat": 4, "nbformat_minor": 5 }