Bonus Tutorial 7: Deep Learning for Climate Prediction with CNN-LSTMs (PyTorch)

Week 2, Day 4, AI and Climate Change

Content creators: Deepak Mewada, Grace Lindsay

Content reviewers: Jenna Pearson

Content editors: Deepak Mewada, Grace Lindsay

Production editors: Jenna Pearson, Konstantine Tsafatinos

Our 2024 Sponsors: CMIP, NFDI4Earth

Tutorial Objectives

Estimated timing of tutorial: 60 minutes

Welcome back! You’ve skillfully applied scikit-learn to climate modeling in Tutorial 1 and Tutorial 2. Now, get ready to dive into the world of Deep Learning using PyTorch! This tutorial focuses on a Convolutional Neural Network (CNN) combined with a Long Short-Term Memory (LSTM) network, a powerful architecture for spatiotemporal data.

In this tutorial, you will learn about:

  • Deep Learning Fundamentals

  • PyTorch Primer

  • Climate Data in Tensors

  • Defining the DL model - CNN-LSTM

  • Training the Model

  • Making predictions with the trained model

Setup

!pip install cartopy xarray --quiet
import numpy as np # Numerical computing
import xarray as xr # Labeled multi-dimensional arrays
import pandas as pd # Data analysis and manipulation
import cartopy.crs as ccrs # Geospatial plotting
import matplotlib.pyplot as plt # Plotting
from types import MethodType
from IPython.display import clear_output
import types

import torch # PyTorch!
import torch.nn as nn # Neural network layers
import torch.optim as optim # Optimization algorithms
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

import random # Random number generation

#from tqdm import tqdm
from ipywidgets import interact, IntSlider
import ipywidgets as widgets
import plotly.express as px

import os
import pooch
import contextlib
import sys

Install and import feedback gadget

# @title Install and import feedback gadget

!pip3 install vibecheck datatops --quiet

from vibecheck import DatatopsContentReviewContainer
def content_review(notebook_section: str):
    return DatatopsContentReviewContainer(
        "",  # No text prompt
        notebook_section,
        {
            "url": "https://pmyvdlilci.execute-api.us-east-1.amazonaws.com/klab",
            "name": "comptools_4clim",
            "user_key": "l5jpxuee",
        },
    ).render()


feedback_prefix = "W2D4_T7"

Figure settings

# @title Figure settings
import ipywidgets as widgets  # interactive display

%config InlineBackend.figure_format = 'retina'
plt.style.use(
    "https://raw.githubusercontent.com/neuromatch/climate-course-content/main/cma.mplstyle"
)

Run this cell for Data retrieval

# @title Run this cell for Data retrieval

# Data retrieval
# This cell downloads a temporary copy of the tutorial dataset into the local runtime
# (e.g., Colab). You can inspect the files via the folder icon in the left sidebar.

# Mapping of filenames to OSF download codes (ssp245 is held out and downloaded as test data below)
file_url_map_train_val = {
    'inputs_historical.nc': 'kqxet',
    'outputs_historical.nc': 'une23',
    'inputs_ssp126.nc': 'jvqg5',
    'outputs_ssp126.nc': '9jmsy',
    #'inputs_ssp245.nc': 'hqvkz',
    #'outputs_ssp245.nc': 'k7fqu',
    'inputs_ssp370.nc': '4snxb',
    'outputs_ssp370.nc': 'zcafm',
    'inputs_ssp585.nc': 'sejxt',
    'outputs_ssp585.nc': 'vwg39',
    'inputs_hist-GHG.nc': 'p84hg',
    'outputs_hist-GHG.nc': 'ys7nu',
    'inputs_hist-aer.nc': 'q7skr',
    'outputs_hist-aer.nc': 'bq3k8',
}

osf_base_url = "https://osf.io/download/"
target_dir_train_val = "Data/train_val/"
os.makedirs(target_dir_train_val, exist_ok=True)

# Context manager to suppress stdout and stderr
@contextlib.contextmanager
def suppress_output():
    with open(os.devnull, 'w') as devnull:
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        try:
            sys.stdout = devnull
            sys.stderr = devnull
            yield
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr  # restore stderr as well

def download_osf_files(file_map, target_dir):
    for filename, code in file_map.items():
        url = osf_base_url + code + "/"
        _ = pooch.retrieve(
            url=url,
            known_hash=None,   # Skip hash check
            fname=filename,
            path=target_dir,
            progressbar=False
        )

# Download all train_val NetCDF files
download_osf_files(file_url_map_train_val, target_dir_train_val)


# --- SETUP: Download test files into Data/test/ ---

file_url_map_test = {
    'inputs_ssp245.nc': '8gpvw',
    'outputs_ssp245.nc': '9pmtx'
}

target_dir_test = "Data/test/"
os.makedirs(target_dir_test, exist_ok=True)

# Download all test files
download_osf_files(file_url_map_test, target_dir_test)

Helper functions

# @title Helper functions  {"run":"auto","display-mode":"form"}

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display, clear_output
import ipywidgets as widgets
from ipywidgets import interact
import warnings
from matplotlib.colors import LogNorm
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", message=".*FrozenMappingWarningOnValuesAccess.*")
# Ensure interactive mode
%matplotlib inline

def plot_climate_heatmap(Y_train, climate_var='pr'):
    """
    Creates an interactive heatmap for visualizing climate variables over time.

    Parameters:
    - Y_train (xarray.Dataset): Dataset containing climate variables.
    - climate_var (str): Variable to visualize (default: 'pr' for precipitation).
    """
    climate_data = Y_train[0][climate_var]  # Extract data for the first simulation

    def plot_data(time_step=0):
        fig = px.imshow(
            climate_data.isel(time=time_step).values,
            color_continuous_scale='viridis',
            labels={'x': "Longitude", 'y': "Latitude"},
            title=f"{climate_var.upper()} at Time Step {time_step}"
        )
        fig.show()

    return widgets.interactive(plot_data, time_step=(0, climate_data.sizes['time'] - 1, 1))


def plot_climate_timeseries(Y_train, climate_var='pr', latitude=50.0, longitude=-120.0):
    """
    Creates an interactive time series plot for a specific location.

    Parameters:
    - Y_train (xarray.Dataset): Climate dataset.
    - climate_var (str): Climate variable to visualize.
    - latitude (float): Latitude of the location.
    - longitude (float): Longitude of the location.
    """
    # Find the closest grid point
    lat_idx = np.abs(Y_train[0]['latitude'] - latitude).argmin()
    lon_idx = np.abs(Y_train[0]['longitude'] - longitude).argmin()

    # Extract time series data for this location
    climate_time_series = Y_train[0][climate_var][:, lat_idx, lon_idx]

    # Create the interactive plot
    fig = px.line(
        x=Y_train[0]['time'],
        y=climate_time_series,
        labels={'x': "Time", 'y': f"{climate_var.upper()}"},
        title=f"{climate_var.upper()} Time Series at ({latitude}, {longitude})"
    )
    fig.show()

def interactive_variable_selector(Y_train, plot_function):
    """
    Creates an interactive dropdown menu to select a climate variable
    and updates the visualization accordingly.

    Parameters:
    - Y_train (xarray.Dataset): Dataset containing climate variables.
    - plot_function (function): Function to visualize the selected climate variable.
    """
    variable_selector = widgets.Dropdown(
        options=list(Y_train[0].data_vars.keys()),
        description="Variable:"
    )

    def update_variable(selected_var):
        plot_function(Y_train, selected_var)  # Call the plotting function with the selected variable

    return widgets.interactive(update_variable, selected_var=variable_selector)

def compare_inputs_outputs(X_train_torch, Y_train_torch):
    """
    Creates an interactive widget to compare input climate variables with
    the target surface air temperature (TAS).

    Parameters:
    - X_train_torch (torch.Tensor): Input climate variables (samples, time, variables, height, width)
    - Y_train_torch (torch.Tensor): Target temperature values (samples, time, height, width)
    """

    def plot_sample(sample_idx, time_step):
        """
        Helper function to plot climate variables and TAS for a given sample and time step.
        """
        input_sample = X_train_torch[sample_idx, time_step].cpu().numpy()  # Shape: (4, 96, 144)
        output_sample = Y_train_torch[sample_idx, 0].cpu().numpy()  # Shape: (96, 144)
        variables = ["CO₂", "CH₄", "SO₂", "Black Carbon"]

        fig, axes = plt.subplots(1, 5, figsize=(20, 4))

        for i in range(4):
            im = axes[i].imshow(input_sample[i], cmap="coolwarm", origin="lower")
            axes[i].set_title(variables[i])
            fig.colorbar(im, ax=axes[i], shrink=0.6)

        # Plot output TAS
        im = axes[4].imshow(output_sample, cmap="coolwarm", origin="lower")
        axes[4].set_title("Surface Air Temperature (TAS)")
        fig.colorbar(im, ax=axes[4], shrink=0.6)

        plt.suptitle(f"Comparison at Time Step {time_step}, Sample {sample_idx}")
        plt.tight_layout()
        plt.show()

    # Interactive Widget
    interact(plot_sample,
             sample_idx=widgets.IntSlider(min=0, max=X_train_torch.shape[0]-1, step=1, value=0, description="Sample"),
             time_step=widgets.IntSlider(min=0, max=X_train_torch.shape[1]-1, step=1, value=0, description="Time Step"))


def animate_climate_variables(X_train_torch, sample_idx=0, scale_mode='auto'):
    """
    Creates an interactive animation to visualize selected climate input variables
    (CH₄, Black Carbon) over time, with optional log or linear scaling.
    """

    # Clear previous figures and outputs
    plt.close('all')
    clear_output(wait=True)

    # Extract input variables for the given sample
    input_seq = X_train_torch[sample_idx].cpu().numpy()  # Shape: (time, variables, height, width)

    # Only use CH₄ (1) and Black Carbon (3)
    selected_indices = [1, 3]
    variables = ["CH₄", "Black Carbon"]

    # Prepare selected input
    input_seq = input_seq[:, selected_indices, :, :]  # Now shape is (time, 2, H, W)

    # Create figure and axes
    fig, axes = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)

    # Precompute color scale limits or normalization
    vmins, vmaxs, norms = [], [], []
    for i in range(len(variables)):
        data = input_seq[:, i, :, :]
        data_flat = data.flatten()
        all_positive = np.all(data_flat > 0)

        if scale_mode in ('log', 'auto') and all_positive:
            # Log scaling between the 5th and 95th percentiles (requires strictly positive data)
            norm = LogNorm(vmin=np.percentile(data_flat, 5), vmax=np.percentile(data_flat, 95))
            norms.append(norm)
            vmins.append(None)
            vmaxs.append(None)
        else:
            # Linear scaling clipped to the 5th-95th percentile range
            norms.append(None)
            vmins.append(np.percentile(data_flat, 5))
            vmaxs.append(np.percentile(data_flat, 95))

    # Initialize image plots
    ims = []
    for i, ax in enumerate(axes):
        if norms[i]:
            im = ax.imshow(input_seq[0, i], cmap="coolwarm", origin="lower", norm=norms[i])
        else:
            im = ax.imshow(input_seq[0, i], cmap="coolwarm", origin="lower", vmin=vmins[i], vmax=vmaxs[i])
        ax.set_title(f"{variables[i]} (Year 1)")
        fig.colorbar(im, ax=ax, shrink=0.6)
        ims.append(im)

    # Animation update function
    def update(frame):
        for i, im in enumerate(ims):
            im.set_data(input_seq[frame, i])
            axes[i].set_title(f"{variables[i]} (Year {frame+1})")

    ani = FuncAnimation(fig, update, frames=input_seq.shape[0], interval=500, repeat=True)

    clear_output(wait=True)
    display(HTML(ani.to_jshtml()))


#  1. Live Loss & Validation Tracking
def plot_loss():
    plt.figure(figsize=(8, 5))
    plt.plot(range(1, len(train_losses) + 1), train_losses, label="Training Loss", marker="o")
    plt.plot(range(1, len(val_losses) + 1), val_losses, label="Validation Loss", marker="s")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training & Validation Loss Over Time")
    plt.legend()
    plt.grid()
    plt.show()

#  2. Weight & Gradient Evolution (Histogram)
def plot_weight_gradients(epoch):
    if epoch < len(weights_history):
        weights = np.concatenate([w.flatten() for w in weights_history[epoch]])
        grads = np.concatenate([g.flatten() for g in grads_history[epoch]]) if grads_history[epoch] else None

        fig, ax = plt.subplots(1, 2, figsize=(12, 5))

        ax[0].hist(weights, bins=50, color="blue", alpha=0.7)
        ax[0].set_title(f"Model Weights Distribution (Epoch {epoch+1})")
        ax[0].set_xlabel("Weight Value")
        ax[0].set_ylabel("Frequency")

        if grads is not None:
            ax[1].hist(grads, bins=50, color="red", alpha=0.7)
            ax[1].set_title(f"Gradient Distribution (Epoch {epoch+1})")
            ax[1].set_xlabel("Gradient Value")
            ax[1].set_ylabel("Frequency")

        plt.show()

#  3. Sample Predictions Over Time (Slider)
def plot_predictions1(epoch):
    cnn_model.eval()
    with torch.no_grad():
        X_input = X_train_torch[:10].to(next(cnn_model.parameters()).device)
        Y_pred = cnn_model(X_input).cpu().numpy()

        #Y_pred = cnn_model(X_train_torch[:10]).cpu().numpy()
        Y_true = Y_train_torch[:10].cpu().numpy()

    plt.figure(figsize=(8, 5))
    plt.plot(Y_true[epoch].flatten(), label="Ground Truth", marker="o")
    plt.plot(Y_pred[epoch].flatten(), label="Predicted", linestyle="dashed", marker="x")
    plt.xlabel("Time Steps")
    plt.ylabel("Climate Variable")
    plt.title(f"Predictions vs. Ground Truth (Sample {epoch+1})")
    plt.legend()
    plt.grid()
    plt.show()

def plot_predictions(y_true, y_pred, title):
    """Plots predicted vs. actual values as spatial maps."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    vmin, vmax = np.percentile(y_true, [5, 95])

    axes[0].imshow(y_true.squeeze(), cmap='coolwarm', vmin=vmin, vmax=vmax)
    axes[0].set_title('Ground Truth')
    axes[0].axis('off')

    axes[1].imshow(y_pred.squeeze(), cmap='coolwarm', vmin=vmin, vmax=vmax)
    axes[1].set_title('Prediction')
    axes[1].axis('off')

    plt.show()

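# Two-argument version for use with `interact` widgets. Note: this redefinition
# replaces the compare_inputs_outputs(X_train_torch, Y_train_torch) defined above.
def compare_inputs_outputs(sample_idx, time_step):
    """
    Compare selected input climate variables (CH₄, Black Carbon) with the target temperature change (TAS).
    """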
    input_sample = X_train_torch[sample_idx, time_step].cpu().numpy()  # (4, 96, 144)
    output_sample = Y_train_torch[sample_idx, 0].cpu().numpy()  # (96, 144)

    # Only include CH₄ (index 1) and Black Carbon (index 3)
    selected_indices = [1, 3]
    variables = ["CH₄", "Black Carbon"]

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    for i, idx in enumerate(selected_indices):
        im = axes[i].imshow(input_sample[idx], cmap="coolwarm", origin="lower")
        axes[i].set_title(variables[i])
        fig.colorbar(im, ax=axes[i], shrink=0.6)

    # Plot output TAS
    im = axes[2].imshow(output_sample, cmap="coolwarm", origin="lower")
    axes[2].set_title("Surface Air Temperature (TAS)")
    fig.colorbar(im, ax=axes[2], shrink=0.6)

    plt.suptitle(f"Comparison at Time Step {time_step}, Sample {sample_idx}")
    plt.tight_layout()
    plt.show()

Set random seed, when using pytorch

Executing set_seed(seed=seed) sets the random seed.

# @title Set random seed, when using `pytorch` {"run":"auto","display-mode":"form"}

# @markdown Executing `set_seed(seed=seed)` you are setting the seed

# For DL it's critical to set the random seed so that students have a
# baseline against which to compare their results.
# Read more here: https://pytorch.org/docs/stable/notes/randomness.html

# Call `set_seed` function in the exercises to ensure reproducibility.


def set_seed(seed=None, seed_torch=True):
  if seed is None:
    seed = np.random.choice(2 ** 32)
  random.seed(seed)
  np.random.seed(seed)
  if seed_torch:
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

  print(f'Random seed {seed} has been set.')


# In case that `DataLoader` is used
def seed_worker(worker_id):
  worker_seed = torch.initial_seed() % 2**32
  np.random.seed(worker_seed)
  random.seed(worker_seed)


set_seed(seed=2021, seed_torch=False)  # change 2021 with any number you like
Random seed 2021 has been set.
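The seed_worker helper above only takes effect once a DataLoader is created. Below is a minimal sketch of the pattern described in PyTorch's reproducibility notes: pass seed_worker as worker_init_fn and give the loader a seeded torch.Generator so that the shuffling order is repeatable. The toy dataset and the make_loader helper are illustrative assumptions, not part of this tutorial's code.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # Same logic as the tutorial helper: derive NumPy and `random` seeds
    # from the torch seed that DataLoader assigns to each worker process.
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def make_loader(dataset, seed, num_workers=0):
    # Hypothetical helper. A dedicated, seeded generator makes the shuffling
    # order reproducible; worker_init_fn is only called when num_workers > 0.
    g = torch.Generator()
    g.manual_seed(seed)
    return DataLoader(dataset, batch_size=4, shuffle=True,
                      num_workers=num_workers,
                      worker_init_fn=seed_worker, generator=g)


# Toy dataset: 10 samples holding the values 0..9.
toy = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))

order_a = [batch[0].squeeze(1).tolist() for batch in make_loader(toy, seed=2021)]
order_b = [batch[0].squeeze(1).tolist() for batch in make_loader(toy, seed=2021)]
print(order_a == order_b)  # True: same seed, same shuffled batches
```

Without the generator argument, the shuffle order would depend on the global torch RNG state at the moment the loader is iterated, which is harder to control in a notebook where cells may be re-run out of order.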

Set device (GPU or CPU). Execute set_device()

# @title Set device (GPU or CPU). Execute `set_device()`
# Run this especially if torch modules are used.

# Inform the user whether the notebook uses a GPU or CPU.

def set_device():
  device = "cuda" if torch.cuda.is_available() else "cpu"
  if device != "cuda":
    print("GPU is not enabled in this notebook. But it will help make training faster if GPU is enabled. \n"
          "If you want to enable it, in the menu under `Runtime` -> \n"
          "`Hardware accelerator.` and select `GPU` from the dropdown menu")
  else:
    print("GPU is enabled in this notebook. \n"
          "If you want to disable it, in the menu under `Runtime` -> \n"
          "`Hardware accelerator.` and select `None` from the dropdown menu")

  return device
device= set_device()
GPU is not enabled in this notebook. But it will help make training faster if GPU is enabled. 
If you want to enable it, in the menu under `Runtime` -> 
`Hardware accelerator.` and select `GPU` from the dropdown menu

Note:
GPU acceleration is optional for this tutorial. All code can be executed on CPUs, though some steps (especially model training) may take longer. On CPU-only environments, allow extra time and read through the rest of the tutorial while long-running cells execute.

Video 1: Deep Learning Techniques

If you want to download the slides: https://osf.io/download/abem5/

Submit your feedback

# @title Submit your feedback
content_review(f"{feedback_prefix}_Deep_Learning_Techniques")

Section 1. Transitioning to Deep Learning with PyTorch: From Machine Learning to Deep Learning in Climate Data Analysis

Section 1.1 Why Move from Machine Learning to Deep Learning?

In our previous tutorials, we used Machine Learning (ML) models, such as Random Forests and Gradient Boosting Machines, to analyze climate data. These models are effective when working with structured, tabular datasets but come with limitations:

1️⃣ Manual Feature Engineering: ML models require carefully selected and engineered features, which may not fully capture hidden patterns in climate data.
2️⃣ Limited Spatial & Temporal Awareness: Climate data is highly spatiotemporal, meaning relationships exist across both space and time, which tabular ML models do not represent explicitly.
3️⃣ Scalability Issues: ML techniques work well on small to medium-sized datasets, but struggle with high-dimensional and large-scale climate datasets.

🔹 Deep Learning (DL), on the other hand, is designed to overcome these challenges by learning features automatically from raw, gridded climate data. Unlike the tabular ML models above, DL models can scale to large datasets, capture complex dependencies, and extract meaningful patterns without manual feature engineering.

| 🔍 Aspect | Machine Learning (Previous Tutorials) | 🚀 Deep Learning (This Tutorial) |
|---|---|---|
| Input Data | Structured tabular format | Raw climate data (NetCDF) |
| Feature Engineering | Manual selection required | Automatic feature extraction |
| Spatial Awareness | Limited or none | Captures spatial dependencies |
| Temporal Awareness | Limited | Captures long-term patterns |
| Scalability | Suitable for small datasets | Efficient for large datasets |

Section 1.1.1 ML Input vs. DL Input: What Changes?

Machine Learning in Previous Tutorials

  • Input: Climate variables from 2015 and projected emissions (2015–2050)

  • Output: Predicted 2050 temperature anomaly

  • Data Format: Tabular representation with location-scenario pairs

  • Limitation: Spatial and temporal dependencies were not explicitly preserved

While ML models performed well, they lacked the ability to capture complex spatiotemporal relationships present in climate data.

Deep Learning for Spatiotemporal Data

Deep learning enables us to work with high-dimensional climate data while maintaining its spatial and temporal structure. Instead of using tabular data, we now process climate data in its original NetCDF format, which includes:

  • Variables: CO₂, CH₄, SO₂, BC (black carbon)

  • Dimensions: (time, latitude, longitude)

  • Input: Entire climate maps over time

  • Output: Future climate projections at a grid level
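To make the difference concrete, here is a minimal sketch of how gridded inputs become a PyTorch tensor that keeps the (time, variables, latitude, longitude) structure, compared with the flattened tabular view used by the earlier ML models. Random NumPy arrays serve as a hypothetical stand-in for the NetCDF fields; with xarray you would typically obtain such arrays from a variable's .values. The 96 x 144 grid matches the map shapes noted in the helper functions above.

```python
import numpy as np
import torch

# Hypothetical stand-in for gridded emission fields loaded from NetCDF.
n_time, n_lat, n_lon = 5, 96, 144
rng = np.random.default_rng(0)
fields = {name: rng.random((n_time, n_lat, n_lon), dtype=np.float32)
          for name in ["CO2", "CH4", "SO2", "BC"]}

# Deep learning input: stack variables as channels and keep the grid intact.
# Resulting shape: (time, variables, latitude, longitude)
x_grid = torch.from_numpy(np.stack([fields[v] for v in fields], axis=1))
print(x_grid.shape)  # torch.Size([5, 4, 96, 144])

# Tabular (classical ML) view: one row per (time, lat, lon) cell and one
# column per variable. Neighbouring grid cells become unrelated rows.
x_tab = x_grid.permute(0, 2, 3, 1).reshape(-1, len(fields))
print(x_tab.shape)  # torch.Size([69120, 4])
```

Both tensors hold the same numbers; only the grid version still tells a convolution which cells are neighbours.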

Why This Transition?

| 🚀 Advantage | 🔍 Benefit in Climate Modeling |
|---|---|
| Retains Spatial Structure | Climate data is naturally spatial; DL can learn patterns across regions |
| Captures Temporal Trends | Climate events are time-dependent; DL can model long-term patterns |
| Works with Raw Data | No need for manual feature engineering; the model extracts features directly |
| Uses CNNs & LSTMs | Specialized layers handle both spatial (CNNs) and temporal (LSTMs) relationships |


Section 1.2 What is Deep Learning?

Deep Learning (DL) is a specialized branch of Machine Learning that uses Artificial Neural Networks (ANNs) to learn from data in a hierarchical manner. Instead of relying on handcrafted features, DL models extract patterns directly from raw data through multiple processing layers.

Key Components of Deep Learning

🔹 1. Neural Networks

  • Deep learning models are composed of artificial neurons (loosely inspired by biological neurons in the brain).

  • Neurons are organized into layers—each transforming the input data into meaningful representations.

  • Stacking many layers makes a network “deep”, hence the term “deep learning”.

🔹 2. Training via Backpropagation

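  • Backpropagation propagates the error (the loss) backward through the network to compute a gradient for every weight, i.e., how much each weight contributed to the error.

  • Gradient descent then uses these gradients to adjust the weights, and the process repeats iteratively to refine the model.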
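The two steps above can be seen in a few lines of PyTorch. This is a minimal sketch on toy data (not climate data) with a one-neuron linear model: loss.backward() performs backpropagation, and the manual update inside torch.no_grad() is plain gradient descent. In practice an optimizer from torch.optim performs this update for you.

```python
import torch

# Toy data for a one-neuron linear model: y = 2x + 1.
x = torch.linspace(-1, 1, 20).unsqueeze(1)   # shape (20, 1)
y = 2 * x + 1

# Parameters tracked by autograd (requires_grad=True).
w = torch.zeros(1, 1, requires_grad=True)
b = torch.zeros(1, requires_grad=True)
lr = 0.1  # learning rate (step size)

for step in range(200):
    y_hat = x @ w + b                  # forward pass
    loss = ((y_hat - y) ** 2).mean()   # mean squared error
    loss.backward()                    # backpropagation: fills w.grad and b.grad
    with torch.no_grad():              # gradient descent update, not tracked by autograd
        w -= lr * w.grad
        b -= lr * b.grad
        w.grad.zero_()                 # reset gradients; .backward() accumulates them
        b.grad.zero_()

print(w.item(), b.item())  # close to the true values 2 and 1
```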

🔹 3. Deep Learning Architectures

  • Convolutional Neural Networks (CNNs): Ideal for processing spatial climate data (e.g., satellite images).

  • Recurrent Neural Networks (RNNs) & LSTMs: Designed for sequential data (e.g., temperature trends over time).

  • Hybrid CNN-LSTM Models: Capture both spatial and temporal dependencies, which makes them well suited to climate prediction.
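As an illustration of how the hybrid fits together, here is a minimal, hypothetical CNN-LSTM sketch. TinyCNNLSTM and all layer sizes are assumptions, not the model defined later in this tutorial. A small CNN encodes each yearly map into a feature vector, an LSTM processes the sequence of those vectors, and a linear layer decodes the final hidden state into one output map. The input layout (samples, time, variables, height, width) follows the X_train_torch shape noted in the helper functions.

```python
import torch
import torch.nn as nn


class TinyCNNLSTM(nn.Module):
    """Hypothetical minimal CNN-LSTM: a CNN encodes each map, an LSTM models the sequence."""

    def __init__(self, n_vars=4, hidden=32, out_hw=(96, 144)):
        super().__init__()
        self.out_hw = out_hw
        # Spatial encoder, applied independently to every time step.
        self.cnn = nn.Sequential(
            nn.Conv2d(n_vars, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((4, 6)),   # fixed-size summary of each map
            nn.Flatten(),                   # -> (8 * 4 * 6,) features per map
        )
        self.lstm = nn.LSTM(input_size=8 * 4 * 6, hidden_size=hidden, batch_first=True)
        # Decode the last hidden state into one output map.
        self.head = nn.Linear(hidden, out_hw[0] * out_hw[1])

    def forward(self, x):
        # x: (batch, time, variables, height, width)
        b, t, c, h, w = x.shape
        feats = self.cnn(x.reshape(b * t, c, h, w))   # (b*t, features)
        feats = feats.reshape(b, t, -1)               # (b, t, features)
        seq_out, _ = self.lstm(feats)                 # (b, t, hidden)
        last = seq_out[:, -1]                         # state after the last time step
        return self.head(last).reshape(b, 1, *self.out_hw)


model = TinyCNNLSTM()
x = torch.randn(2, 10, 4, 96, 144)  # 2 samples, 10 time steps, 4 emission variables
print(model(x).shape)  # torch.Size([2, 1, 96, 144])
```

Folding the time axis into the batch axis (reshape(b * t, ...)) lets a single CNN share its weights across all time steps, and the adaptive pooling keeps the LSTM input small regardless of the grid size.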


Section 1.3 Why Deep Learning?

  • Automated Feature Extraction: Deep learning models automatically discover complex relationships in the data, reducing reliance on manual feature engineering.

  • Spatiotemporal Modeling: CNN-LSTMs can analyze spatial patterns and temporal dependencies together, which the tabular models of the previous tutorials could not do directly.

  • Handles Complex Data: Deep learning can handle high-dimensional climate data more effectively than the previous approaches.

Section 1.4 Why PyTorch?

To build our deep learning models, we will use PyTorch—one of the most widely used deep learning frameworks. PyTorch provides flexibility, intuitive coding, and GPU acceleration, making it an excellent choice for research and production applications.

Advantages of PyTorch

  • Dynamic Computation Graphs: PyTorch builds the computational graph on the fly as your code runs (define-by-run), so you can inspect and debug models with ordinary Python tools. (TensorFlow 1.x required static graphs; TensorFlow 2 also defaults to eager execution.)

  • Easy-to-Use API: Simple, Pythonic syntax that integrates closely with NumPy.

  • Efficient GPU Acceleration: Moving tensors and models to a GPU can speed up training substantially, which helps deep learning models scale.

  • Robust Library Ecosystem: Includes built-in modules for automatic differentiation, optimization, and dataset handling.
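Here is a minimal sketch of the first two advantages, using only numpy and torch. torch.from_numpy shares memory with the source array, and autograd differentiates whatever operations actually ran, including a loop whose length depends on the data. The scaled_norm function is a hypothetical toy example.

```python
import numpy as np
import torch

# NumPy interop: torch.from_numpy shares memory with the source array (no copy).
arr = np.array([1.0, 2.0, 3.0])
t = torch.from_numpy(arr)
arr[0] = 10.0
print(t)          # the tensor reflects the change made to arr
back = t.numpy()  # converting back also avoids a copy (CPU tensors only)


# Define-by-run: ordinary Python control flow shapes the graph at run time.
def scaled_norm(v):
    out = v
    while out.norm() < 100:  # the number of iterations depends on the data
        out = out * 2
    return out.sum()


v = torch.tensor([1.0, 1.0], requires_grad=True)
s = scaled_norm(v)
s.backward()
print(v.grad)  # gradient of the doublings that actually executed
```

For v = [1, 1] the loop doubles seven times before the norm exceeds 100, so each gradient entry is 2**7 = 128; a different input would produce a different graph.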

PyTorch Essentials for This Tutorial

🔧 PyTorch Module

Purpose

torch.Tensor

Core data structure for PyTorch models

torch.nn

Provides layers like CNN, LSTM, etc.

torch.optim

Optimizers for training models

torch.autograd

Automatic differentiation for backpropagation

torch.utils.data

Handles datasets and dataloaders

💡 In this tutorial, we will use PyTorch to implement a CNN-LSTM model for climate prediction, leveraging both spatial and temporal patterns in raw climate data.
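Here is a minimal sketch of how the five modules in the table work together in a training loop. It uses a toy linear-regression problem, and the data, model size, learning rate, and epoch count are illustrative assumptions rather than the settings used later in this tutorial.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# torch.Tensor: toy inputs and targets (a stand-in for climate fields).
X = torch.randn(256, 3)
true_w = torch.tensor([[1.5], [-2.0], [0.5]])
Y = X @ true_w + 0.1

# torch.utils.data: wrap tensors and iterate over shuffled mini-batches.
loader = DataLoader(TensorDataset(X, Y), batch_size=32, shuffle=True)

# torch.nn: model and loss function.
model = nn.Linear(3, 1)
loss_fn = nn.MSELoss()

# torch.optim: the optimizer updates model.parameters().
optimizer = optim.Adam(model.parameters(), lr=0.05)

for epoch in range(50):
    for xb, yb in loader:
        optimizer.zero_grad()          # clear gradients from the previous step
        loss = loss_fn(model(xb), yb)
        loss.backward()                # torch.autograd computes the gradients
        optimizer.step()               # apply the parameter update

with torch.no_grad():
    final_loss = loss_fn(model(X), Y).item()
print(f"final MSE: {final_loss:.5f}")
```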


Section 1.5 Critique of the Previous ML Tutorial: What Can Be Improved?

While our previous ML-based approach was effective, it had some limitations that we aim to address with deep learning:

🔹 Limited Generalization: The ML model was trained on a condensed dataset, meaning it might not generalize well to real-world, large-scale climate data.
🔹 Feature Engineering Dependency: The performance of ML models heavily depends on manual feature selection, which is time-consuming and requires domain expertise.
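🔹 Inability to Capture Spatial/Temporal Dependencies: Tree-based ML models treat each input feature as a separate column with no built-in notion of which grid cells are neighbors or which time steps follow one another, so they miss crucial spatial and temporal correlations in climate data.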
🔹 Scalability Issues: As climate datasets grow in size, traditional ML methods struggle to handle the increasing data complexity efficiently.

By moving to deep learning, we address these shortcomings by:
✅ Using raw, high-dimensional climate data instead of condensed versions.
✅ Leveraging CNNs and LSTMs to automatically learn patterns from spatial and temporal data.
✅ Utilizing GPU-accelerated PyTorch models to efficiently handle large datasets.

🚀 Next, let's load and preprocess our climate dataset for deep learning!

Submit your feedback#

# @title Submit your feedback
content_review(f"{feedback_prefix}_Section_1")

Section 2: ClimateBench Data Reloaded - Now in PyTorch!#

In this section, we transition from the pandas and scikit-learn world of Tutorials 1 and 2 to the tensor-centric universe of PyTorch. We’ll load a similar ClimateBench dataset, but prepare it for our CNN-LSTM architecture.

As before, we need a set of tools. Note the key change: we bring in torch and its related modules, which were imported in the Setup section.

Note: For a detailed understanding of PyTorch and deep learning, refer to the Neuromatch Deep Learning course's W1D1 Tutorial 1.

Working with Tensors: PyTorch's Core Data Structure

PyTorch revolves around tensors, which are multi-dimensional arrays similar to NumPy arrays, but with the added benefit of GPU acceleration. Think of a tensor as the fundamental building block for representing data in neural networks.

Why Tensors?

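  • GPU Acceleration: Tensors can be placed on a GPU, which greatly speeds up the large matrix operations inside neural networks.

  • Automatic Differentiation: PyTorch records the operations applied to tensors, so gradients for training are computed automatically.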

  • Flexibility: Represent various data types (floats, integers, etc.).
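To make these three properties concrete, here is a minimal sketch using a made-up 3 x 4 "temperature" grid (the values are illustrative, not ClimateBench data). It converts a NumPy array to a tensor, moves it to the GPU if one is available, and uses autograd to compute the gradient of a toy loss with respect to a learnable scale factor.

```python
import numpy as np
import torch

# Pick the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# A made-up temperature field on a 3 x 4 lat/lon grid, stored as a NumPy array
field_np = np.arange(12, dtype=np.float32).reshape(3, 4)

# torch.from_numpy shares memory with the NumPy array (no copy on the CPU)
field = torch.from_numpy(field_np)
print(field.shape, field.dtype)  # torch.Size([3, 4]) torch.float32

# .to(device) copies the tensor to the GPU when one is available
field = field.to(device)

# Automatic differentiation: track operations on a learnable scale factor
scale = torch.tensor(2.0, device=device, requires_grad=True)
loss = ((scale * field) ** 2).mean()  # a toy scalar "loss"
loss.backward()                       # fills scale.grad with d(loss)/d(scale)

# Analytically, d/ds mean((s * x)^2) = 2 * s * mean(x^2), so the two should match
expected = 2 * scale.detach() * (field ** 2).mean()
print(scale.grad.item(), expected.item())
```

The same pattern (tensors on a device, operations tracked for gradients) is what happens at a much larger scale when the CNN-LSTM is trained.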

Essential Components for Climate-Informed Deep Learning
  • torch.Tensor: The base data structure for representing climate variables (temperature, emissions, etc.).

  • torch.nn: A module containing building blocks for defining our CNN-LSTM model architecture (convolutional layers, LSTM layers, etc.).

  • torch.optim: Optimization algorithms (e.g., Adam) to train the model effectively.

  • torch.utils.data.Dataset & torch.utils.data.DataLoader: Powerful tools for managing large climate datasets and efficiently feeding them into our model during training.
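As a quick illustration of how TensorDataset and DataLoader feed batches to a model, the sketch below uses random tensors shaped like the spatiotemporal inputs described later in this section (samples, time, lat, lon, variables), but on a deliberately shrunken 8 x 12 grid. The sizes are assumptions chosen only to keep the example light.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)

# Hypothetical toy data: 20 samples, each a 10-step sequence of 4 emission maps on an 8 x 12 grid
# (the real ClimateBench grid is 96 x 144; it is shrunk here on purpose)
X = torch.randn(20, 10, 8, 12, 4)   # (samples, time, lat, lon, variables)
Y = torch.randn(20, 1, 8, 12)       # (samples, 1, lat, lon) temperature targets

dataset = TensorDataset(X, Y)       # pairs X[i] with Y[i]
loader = DataLoader(dataset, batch_size=8, shuffle=True)  # reshuffled every epoch

for xb, yb in loader:
    # 20 samples with batch_size=8 give batches of 8, 8 and a final batch of 4
    print(xb.shape, yb.shape)
```

For real data, X and Y would come from the preprocessed ClimateBench arrays instead of torch.randn; the DataLoader logic stays the same.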


PyTorch Core Component Breakdown

| Component | Symbol | Purpose | Climate Application Example |
|---|---|---|---|
| Tensors | | GPU-accelerated multidimensional arrays | Store 3D atmospheric data cubes |
| nn.Module | 🧱 | Neural network building blocks | Create CNN-LSTM hybrid architectures |
| Optimizers | 🎯 | Parameter update strategies | Adam for stable climate model training |
| DataLoaders | 📂 | Batch processing & shuffling | Handle decades of climate observations |
| Loss Functions | 📉 | Model performance quantification | MSE for temperature prediction |


🔄 Workflow Insight: Typical development pattern: 1. Tensor Preparation → 2. Model Architecture → 3. Loss/Optimizer Setup → 4. Training Loop → 5. Validation
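The five-step pattern above can be sketched end to end on synthetic data. This is a minimal illustration, not the tutorial's CNN-LSTM: the model is a hypothetical two-layer convolutional network (TinyConvModel), and the target is an invented fixed mix of four input channels on an 8 x 12 grid, so we can check that the validation MSE drops.

```python
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Tensor preparation (synthetic stand-in data, NOT ClimateBench):
#    4 input channels on an 8 x 12 grid; the target is a fixed weighted sum of the channels
X = torch.randn(64, 4, 8, 12)
true_weights = torch.tensor([0.5, -1.0, 2.0, 0.3]).view(1, 4, 1, 1)
Y = (X * true_weights).sum(dim=1, keepdim=True)      # shape (64, 1, 8, 12)
train_loader = DataLoader(TensorDataset(X[:48], Y[:48]), batch_size=16, shuffle=True)
val_loader = DataLoader(TensorDataset(X[48:], Y[48:]), batch_size=16)

# 2. Model architecture: a tiny hypothetical convolutional model
class TinyConvModel(nn.Module):
    def __init__(self, in_channels=4):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 1, kernel_size=3, padding=1),  # padding=1 keeps the 8 x 12 grid size
        )

    def forward(self, x):
        return self.net(x)

model = TinyConvModel().to(device)

# 3. Loss and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

def evaluate(loader):
    """Average MSE over a loader, without tracking gradients."""
    model.eval()
    total, n = 0.0, 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            total += criterion(model(xb), yb).item() * xb.size(0)
            n += xb.size(0)
    return total / n

initial_val_loss = evaluate(val_loader)

# 4. Training loop
for epoch in range(30):
    model.train()
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()             # clear gradients from the previous step
        loss = criterion(model(xb), yb)   # forward pass + loss
        loss.backward()                   # backpropagation
        optimizer.step()                  # parameter update

# 5. Validation
final_val_loss = evaluate(val_loader)
print(f"validation MSE: {initial_val_loss:.3f} -> {final_val_loss:.3f}")
```

Keeping validation in a separate function with torch.no_grad() mirrors step 5 of the workflow and avoids accidentally updating the model on held-out data.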

Section 2.1 The Shift to Spatiotemporal Data: ClimateBench in Native Format#

In Tutorials 1 and 2, we trained machine learning models using a simplified, spatially averaged dataset. While this approach was useful, it had limitations:

  • Loss of spatial information, reducing the model’s ability to capture regional climate variations

  • Limited temporal structure, as time-series emissions were flattened into tabular form

Now, we transition to deep learning, unlocking the full potential of the ClimateBench dataset by preserving its original spatial and temporal structure.


Recap: What Was the Previous Data Format?#

Previously, we averaged across spatial dimensions and flattened the emissions time series, resulting in:

  • Shape: (3240, 152)

    • 3240 rows → location-scenario combinations

    • 152 columns → 2015 climate variables + time-averaged emissions

This simplified dataset was easier to process with scikit-learn, but it sacrificed critical spatial and temporal dependencies.
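A tiny example shows why spatial averaging throws information away. The two maps below are invented for illustration: one is warm in the north, the other warm in the south, yet averaging over the grid collapses both to the same number. (This uses a plain unweighted mean; a real global mean would normally weight each latitude by its area.)

```python
import numpy as np

# Two hypothetical 3 x 4 temperature-anomaly maps (lat x lon), made up for illustration
warm_north = np.array([[2.0, 2.0, 2.0, 2.0],
                       [0.0, 0.0, 0.0, 0.0],
                       [-2.0, -2.0, -2.0, -2.0]])
warm_south = warm_north[::-1]  # same values, flipped north-south

# Averaging over both spatial dimensions reduces each map to a single number
print(warm_north.mean(), warm_south.mean())    # 0.0 0.0
print(np.array_equal(warm_north, warm_south))  # False
```

A model that only sees the averaged value cannot tell these two very different climates apart, which is exactly the information a CNN can recover from the full grid.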


Recap: What Data Are We Going to Use Now?#

We now work with the original NetCDF structure, which explicitly retains all spatial and temporal information. This dataset will be structured as:

  • Input (X): shape (766, 10, 96, 144, 4)

    • 766 sequences → extracted from all climate simulations

    • 10 time steps → sliding window approach over years

    • 96 latitude points → range: -90° to 90°

    • 144 longitude points → range: 0° to 357.5° (2.5° increments)

    • 4 emission variables → CO₂, CH₄, SO₂, BC

  • Target (Y): shape (766, 1, 96, 144)

    • Single time-step temperature anomaly prediction

This structure allows deep learning models to capture spatial dependencies (across latitude and longitude) and learn temporal trends (using recurrent or convolutional layers).
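The sliding-window step listed above (10 time steps per sequence) can be sketched in NumPy. The helper make_windows is hypothetical, the data are random stand-ins on a small 6 x 8 grid, and we assume each window's target is the temperature in the window's final year; the tutorial's own preprocessing may differ in these details.

```python
import numpy as np

def make_windows(inputs, target, window=10):
    """Slice one simulation into overlapping windows (hypothetical helper).

    inputs: (years, lat, lon, n_vars) emission maps
    target: (years, lat, lon) temperature maps
    Returns X with shape (n_windows, window, lat, lon, n_vars) and
    Y with shape (n_windows, 1, lat, lon); each target is assumed to be
    the temperature in the LAST year of its window.
    """
    n_windows = inputs.shape[0] - window + 1
    X = np.stack([inputs[i:i + window] for i in range(n_windows)])
    Y = np.stack([target[i + window - 1] for i in range(n_windows)])[:, np.newaxis]
    return X, Y

# Random stand-in for one scenario: 30 years on a small 6 x 8 grid with 4 variables
rng = np.random.default_rng(0)
emissions = rng.normal(size=(30, 6, 8, 4)).astype(np.float32)
temperature = rng.normal(size=(30, 6, 8)).astype(np.float32)

X, Y = make_windows(emissions, temperature, window=10)
print(X.shape, Y.shape)  # (21, 10, 6, 8, 4) (21, 1, 6, 8)
```

With several scenarios, the windows from each simulation would be built separately and then joined along the first axis (for example with np.concatenate), so no window ever spans two different scenarios.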


Overall#

In Tutorials 1 & 2, we worked with a pre-processed, tabular dataset optimized for scikit-learn.
Now, we will directly load and process the raw ClimateBench dataset using xarray. This ensures that our deep learning models can fully leverage the spatiotemporal structure of climate data.
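Once the xarray data are converted to tensors (for example with torch.from_numpy on a variable's .values), a CNN-LSTM has to reconcile two layouts: convolutions expect channels-first maps, while the LSTM expects a sequence of feature vectors. The hypothetical TinyCNNLSTM below is only a shape-tracing sketch under those assumptions (one conv layer, global average pooling, one LSTM, a linear decoder); it is not the model defined later in this tutorial.

```python
import torch
import torch.nn as nn

class TinyCNNLSTM(nn.Module):
    """Hypothetical, minimal CNN-LSTM used only to trace tensor shapes."""

    def __init__(self, n_vars=4, hidden=16, lat=96, lon=144):
        super().__init__()
        self.lat, self.lon = lat, lon
        self.cnn = nn.Sequential(
            nn.Conv2d(n_vars, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # collapse each map to one 8-dim feature vector
        )
        self.lstm = nn.LSTM(input_size=8, hidden_size=hidden, batch_first=True)
        self.head = nn.Linear(hidden, lat * lon)  # decode back to a full lat x lon map

    def forward(self, x):
        # x: (batch, time, lat, lon, vars), the layout described above
        b, t, h, w, c = x.shape
        x = x.permute(0, 1, 4, 2, 3)           # -> (batch, time, vars, lat, lon)
        x = x.reshape(b * t, c, h, w)          # fold time into batch for Conv2d
        feats = self.cnn(x).reshape(b, t, -1)  # -> (batch, time, 8)
        out, _ = self.lstm(feats)              # -> (batch, time, hidden)
        last = out[:, -1]                      # keep only the final time step
        return self.head(last).reshape(b, 1, self.lat, self.lon)

model = TinyCNNLSTM()
x = torch.randn(2, 10, 96, 144, 4)  # 2 random sequences in the input layout above
y = model(x)
print(y.shape)  # torch.Size([2, 1, 96, 144])
```

Folding time into the batch dimension lets one set of convolution weights be shared across all 10 years, which is the same idea a time-distributed CNN layer expresses.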

💡 Next Step: Let’s dive into the code and see how to load and preprocess this data! 🚀

First, define the path to the training data, then define the climate scenarios of interest.

Section 2.1.1 Set Data path#

data_path = "Data/train_val/"  # Path to the data, which is temporarily loaded into Colab RAM
# Define the climate scenarios of interest
simus = ['ssp126', 'ssp370', 'ssp585', 'hist-GHG', 'hist-aer']
len_historical = 165  # Length of the historical period (in years)