# Copyright 2026 IPSL / CNRS / Sorbonne University
# Authors: Kishanthan Kingston
#
# This work is licensed under the Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc-sa/4.0/
import argparse
import gc
from pathlib import Path
import numpy as np
import xarray as xr
from AID_BC.logger import Logger
from AID_BC.quantile_mapping import QM
[docs]
def parse_args(argv=None):
    """
    Parse and validate command-line arguments.

    Parameters
    ----------
    argv : list[str] or None, optional
        Argument strings to parse. When ``None`` (the default), argparse
        falls back to ``sys.argv[1:]``, so existing callers are unaffected.

    Returns
    -------
    argparse.Namespace
        Parsed arguments.

    Raises
    ------
    SystemExit
        Raised by argparse (exit status 2) when a required argument is
        missing, a value cannot be converted, the training period is
        reversed, or a chunk size is not strictly positive.
    """
    parser = argparse.ArgumentParser(
        description="Quantile Mapping bias correction using preprocessed Zarr data"
    )
    # First year used to train the QM correction
    parser.add_argument(
        "--train_start", type=int, required=True, help="Training start year"
    )
    # Last year used to train the QM correction
    parser.add_argument(
        "--train_end", type=int, required=True, help="Training end year"
    )
    # Year whose data will be bias-corrected
    parser.add_argument(
        "--apply_year", type=int, required=True, help="Application year"
    )
    parser.add_argument("--variable", type=str, default="VAR_2T", help="Variable name")
    parser.add_argument(
        "--era5_root", type=str, required=True, help="ERA5 root directory"
    )
    parser.add_argument(
        "--cmip6_train_zarr",
        type=str,
        required=True,
        help="Preprocessed CMIP6 training Zarr path",
    )
    parser.add_argument(
        "--cmip6_apply_zarr",
        type=str,
        required=True,
        help="Preprocessed CMIP6 application Zarr path",
    )
    parser.add_argument(
        "--output_dir", type=str, required=True, help="Output directory"
    )
    # Number of latitude points processed at once
    parser.add_argument(
        "--chunk_lat", type=int, default=144, help="Latitude chunk size"
    )
    # Number of longitude points processed at once
    parser.add_argument(
        "--chunk_lon", type=int, default=360, help="Longitude chunk size"
    )
    args = parser.parse_args(argv)

    # A reversed period would produce an empty list of ERA5 files and only
    # fail much later, so reject it at the command line instead.
    if args.train_end < args.train_start:
        parser.error(
            f"--train_end ({args.train_end}) must not be earlier than "
            f"--train_start ({args.train_start})"
        )
    # Chunk sizes are used as range() steps; zero or negative values cannot
    # tile the grid.
    if args.chunk_lat < 1 or args.chunk_lon < 1:
        parser.error("--chunk_lat and --chunk_lon must be positive integers")
    return args
[docs]
def build_era5_paths(start_year, end_year, era5_root):
    """
    List the yearly ERA5 files covering the training period.

    Parameters
    ----------
    start_year : int
        First training year (inclusive).
    end_year : int
        Last training year (inclusive).
    era5_root : str
        Directory containing the ERA5 files.

    Returns
    -------
    list[str]
        One path per year, in chronological order. The list is empty when
        ``end_year`` is smaller than ``start_year``.
    """
    # Files follow the samples_<year>.nc naming convention (same as IPSL-AID)
    root = Path(era5_root)
    years = range(start_year, end_year + 1)
    return [str(root.joinpath(f"samples_{year}.nc")) for year in years]
[docs]
def iter_spatial_chunks(n_lat, n_lon, chunk_lat, chunk_lon):
    """
    Iterate over spatial chunks.

    Chunks are produced in row-major order: every longitude chunk of the
    first latitude band, then the next latitude band, and so on.

    Parameters
    ----------
    n_lat : int
        Number of latitude points.
    n_lon : int
        Number of longitude points.
    chunk_lat : int
        Latitude chunk size (must be >= 1).
    chunk_lon : int
        Longitude chunk size (must be >= 1).

    Yields
    ------
    tuple[slice, slice]
        Latitude and longitude slices defining one spatial chunk.

    Raises
    ------
    ValueError
        If ``chunk_lat`` or ``chunk_lon`` is not strictly positive. Because
        this is a generator, the error surfaces on the first iteration.
    """
    # A negative step would make range() silently yield nothing, and a zero
    # step raises an unrelated-looking error, so fail with a clear message.
    if chunk_lat < 1 or chunk_lon < 1:
        raise ValueError(
            "chunk_lat and chunk_lon must be positive integers, "
            f"got chunk_lat={chunk_lat}, chunk_lon={chunk_lon}"
        )
    # Loop over latitude indices by chunk
    for lat_start in range(0, n_lat, chunk_lat):
        # Ensure the last latitude chunk does not exceed the grid size
        lat_end = min(lat_start + chunk_lat, n_lat)
        # Loop over longitude indices by chunk
        for lon_start in range(0, n_lon, chunk_lon):
            # Ensure the last longitude chunk does not exceed the grid size
            lon_end = min(lon_start + chunk_lon, n_lon)
            yield (slice(lat_start, lat_end), slice(lon_start, lon_end))
[docs]
def apply_qm_by_spatial_chunks(
    Y_train, X_train, X_apply, variable_name, chunk_lat, chunk_lon, logger
):
    """
    Apply Quantile Mapping chunk by chunk.

    The grid is tiled into latitude/longitude chunks. For each chunk a new
    ``QM`` model is fitted on the training data and used to correct the
    application data; the corrected chunk is written into one output array.

    Parameters
    ----------
    Y_train : xr.DataArray
        ERA5 reference training data.
    X_train : xr.DataArray
        Preprocessed CMIP6 training data on ERA5 grid.
    X_apply : xr.DataArray
        Preprocessed CMIP6 application data on ERA5 grid.
    variable_name : str
        Variable name assigned to the returned DataArray.
    chunk_lat : int
        Latitude chunk size (must be >= 1).
    chunk_lon : int
        Longitude chunk size (must be >= 1).
    logger : Logger
        Logger instance.

    Returns
    -------
    corr : xr.DataArray
        Bias-corrected application data (float32), carrying the metadata of
        ``X_apply`` with dimensions ordered (time, latitude, longitude).

    Raises
    ------
    ValueError
        If ``chunk_lat`` or ``chunk_lon`` is not strictly positive.
    """
    # Without this guard a negative chunk size produces no chunks at all and
    # the uninitialised output buffer would be returned as "corrected" data.
    if chunk_lat < 1 or chunk_lon < 1:
        raise ValueError(
            "chunk_lat and chunk_lon must be positive integers, "
            f"got chunk_lat={chunk_lat}, chunk_lon={chunk_lon}"
        )

    # Ensure all input arrays use the same dimension order
    Y_train = Y_train.transpose("time", "latitude", "longitude")
    X_train = X_train.transpose("time", "latitude", "longitude")
    X_apply = X_apply.transpose("time", "latitude", "longitude")

    # Get the number of latitude and longitude points
    n_lat = X_apply.sizes["latitude"]
    n_lon = X_apply.sizes["longitude"]

    # Allocate the full output array that will store corrected values;
    # every cell is overwritten by exactly one chunk below
    Z_apply = np.empty(X_apply.shape, dtype=np.float32)

    # Total number of chunks (integer ceil-division along each axis)
    total_chunks = (-(-n_lat // chunk_lat)) * (-(-n_lon // chunk_lon))

    # Process each spatial chunk independently to reduce memory usage
    chunks = iter_spatial_chunks(
        n_lat=n_lat, n_lon=n_lon, chunk_lat=chunk_lat, chunk_lon=chunk_lon
    )
    for chunk_id, (lat_slice, lon_slice) in enumerate(chunks, start=1):
        logger.info(
            f"Processing spatial chunk "
            f"{chunk_id}/{total_chunks} | "
            f"lat={lat_slice}, lon={lon_slice}"
        )

        logger.info("Reading ERA5 training chunk")
        # Read the ERA5 reference data for the current spatial chunk
        Y_chunk = Y_train.isel(latitude=lat_slice, longitude=lon_slice).values.astype(
            np.float32
        )

        logger.info("Reading CMIP6 training chunk")
        # Read the CMIP6 training data for the same spatial chunk
        X_train_chunk = X_train.isel(
            latitude=lat_slice, longitude=lon_slice
        ).values.astype(np.float32)

        logger.info("Reading CMIP6 application chunk")
        # Read the CMIP6 data to be corrected for the same spatial chunk
        X_apply_chunk = X_apply.isel(
            latitude=lat_slice, longitude=lon_slice
        ).values.astype(np.float32)

        # Reshape data from 3D: time x latitude x longitude
        # to 2D: time x grid_points, as required by the QM model
        Y_train_2D = Y_chunk.reshape(Y_chunk.shape[0], -1)
        X_train_2D = X_train_chunk.reshape(X_train_chunk.shape[0], -1)
        X_apply_2D = X_apply_chunk.reshape(X_apply_chunk.shape[0], -1)
        logger.info(f"Y_train_2D shape: {Y_train_2D.shape}")
        logger.info(f"X_train_2D shape: {X_train_2D.shape}")
        logger.info(f"X_apply_2D shape: {X_apply_2D.shape}")

        # Create a new Quantile Mapping model for this spatial chunk
        qm = QM()
        # Fit Quantile Mapping using ERA5 as reference and CMIP6 as model data
        qm.fit(Y0=Y_train_2D, X0=X_train_2D)
        # Apply the fitted Quantile Mapping model to the application data
        Z_chunk_2D = qm.predict(X0=X_apply_2D)

        # Reshape the corrected data back to the original 3D chunk shape
        Z_chunk = Z_chunk_2D.astype(np.float32).reshape(X_apply_chunk.shape)
        # Insert the corrected chunk into the full output array
        Z_apply[:, lat_slice, lon_slice] = Z_chunk

        # Drop the temporaries now, so the previous chunk's arrays are not
        # still alive while the next chunk is being read
        del Y_chunk, X_train_chunk, X_apply_chunk
        del Y_train_2D, X_train_2D, X_apply_2D
        del Z_chunk_2D, Z_chunk, qm
        # Force garbage collection after each chunk
        gc.collect()

    # Create a corrected xarray DataArray using the metadata of the input data
    corr = X_apply.copy(data=Z_apply)
    # Assign the variable name to the corrected DataArray
    corr.name = variable_name
    return corr
[docs]
def main():
    """
    Main Quantile Mapping workflow.

    Steps: parse the command-line arguments, open the ERA5 training files
    and the two preprocessed CMIP6 Zarr stores, check that the ERA5 and
    CMIP6 training arrays have the same shape, run the chunked Quantile
    Mapping correction, and write the result to
    ``<output_dir>/samples_<apply_year>.nc``.

    Raises
    ------
    ValueError
        If the ERA5 and CMIP6 training arrays do not have the same shape.
    """
    # Local import: only this entry point needs it
    from contextlib import ExitStack

    args = parse_args()
    logger = Logger()
    train_years = args.train_end - args.train_start + 1

    # Each dataset's close() is registered as soon as it is opened, so file
    # handles are released even when a later step raises.
    with ExitStack() as stack:
        logger.info(f"Opening ERA5 training data ({args.train_start}-{args.train_end})")
        # Load ERA5 reference files for the full training period
        era5_paths = build_era5_paths(
            start_year=args.train_start,
            end_year=args.train_end,
            era5_root=args.era5_root,
        )
        era5_train_ds = xr.open_mfdataset(
            era5_paths,
            combine="nested",
            concat_dim="time",
            engine="netcdf4",
            cache=False,
        )
        stack.callback(era5_train_ds.close)
        Y_train = era5_train_ds[args.variable]

        logger.info(
            f"Opening preprocessed CMIP6 training Zarr:\n{args.cmip6_train_zarr}"
        )
        # Load preprocessed CMIP6 data used to train the correction
        cmip6_train_ds = xr.open_zarr(args.cmip6_train_zarr)
        stack.callback(cmip6_train_ds.close)
        X_train = cmip6_train_ds[args.variable]

        logger.info(
            f"Opening preprocessed CMIP6 application Zarr:\n{args.cmip6_apply_zarr}"
        )
        # Load preprocessed CMIP6 data for the year that will be bias-corrected
        cmip6_apply_ds = xr.open_zarr(args.cmip6_apply_zarr)
        stack.callback(cmip6_apply_ds.close)
        X_apply = cmip6_apply_ds[args.variable]

        logger.info(f"ERA5 training shape : {Y_train.shape}")
        logger.info(f"CMIP6 training shape : {X_train.shape}")
        logger.info(f"CMIP6 application shape : {X_apply.shape}")

        # ERA5 reference and CMIP6 training data must be aligned in time and space
        if Y_train.shape != X_train.shape:
            raise ValueError(
                "ERA5 training data and CMIP6 training data "
                f"do not have the same shape: "
                f"{Y_train.shape} != {X_train.shape}"
            )

        logger.info(
            f"Applying chunked Quantile Mapping "
            f"trained on {train_years} year(s) "
            f"({args.train_start}-{args.train_end})"
        )
        # Apply Quantile Mapping correction to the application year
        corr = apply_qm_by_spatial_chunks(
            Y_train=Y_train,
            X_train=X_train,
            X_apply=X_apply,
            variable_name=args.variable,
            chunk_lat=args.chunk_lat,
            chunk_lon=args.chunk_lon,
            logger=logger,
        )

        # Build the output NetCDF file path
        output_dir = Path(args.output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        # Save the corrected data using the same file naming convention as the inputs
        output_file = output_dir / f"samples_{args.apply_year}.nc"
        logger.info(f"Saving corrected dataset to:\n{output_file}")
        # Written while the inputs are still open, as in the original flow
        corr.to_netcdf(output_file)
        logger.success("Corrected dataset successfully saved")
    # Leaving the block closes the three datasets (in reverse opening order).
# Run the workflow only when this file is executed as a script, not on import
if __name__ == "__main__":
    main()