# Copyright (c) 2021 CNES/JPL
#
# All rights reserved. Use of this source code is governed by a
# BSD-style license that can be found in the LICENSE file.
"""
Data handler
============
Set of classes that help develop plugins for ssh and swh. This includes a
default netcdf reader for loading a dataset, and default interpolators for both
regular and irregular grids.
"""
import abc
import datetime
import logging
import os
import re
import time
import dask.array as da
import numba as nb
import numpy as np
import pyinterp
import pyinterp.backends.xarray
import xarray as xr
from . import Interface
#: Module-level logger, named after this module (``__name__``).
LOGGER = logging.getLogger(__name__)
class DatasetLoader:
    """Interface in charge of loading the model data used by a plugin.

    Keeping the data loading apart from the interpolator definition lets
    both evolve independently. A data loader has a single task:

    * Take the time range over which an interpolation is requested.
    * Load the model dataset needed to perform that interpolation.
    * Expose it under canonical variable names.
    """

    # NOTE(review): this class does not derive from ``abc.ABC``, so the
    # ``abc.abstractmethod`` marker below is not enforced at instantiation.
    @abc.abstractmethod
    def load_dataset(self, first_date: np.datetime64,
                     last_date: np.datetime64) -> xr.Dataset:
        """Load the model data as a :py:class:`xarray.Dataset`.

        The returned dataset must hold enough values to interpolate both
        ``first_date`` and ``last_date``, which means that its time interval
        is slightly larger than ``[first_date, last_date]``.

        Moreover, the dataset must expose the longitude, latitude, time and
        sea surface height under the canonical names ``lon``, ``lat``,
        ``time`` and ``ssh``.

        Args:
            first_date (numpy.datetime64): The first date that needs to be
                interpolated
            last_date (numpy.datetime64): The last date that needs to be
                interpolated

        Returns:
            xr.Dataset: dataset containing lon, lat, time and ssh variables,
            with canonical names.

        See also:
            :py:meth:`DatasetLoader._shift_date`
        """
        ...

    @staticmethod
    def _shift_date(date: np.datetime64, shift: int,
                    time_delta: np.timedelta64) -> np.datetime64:
        """Move a date by one model time step when it is off the time axis.

        This is useful to build the time interval over which original model
        values are required.

        Args:
            date (np.datetime64): interpolation date
            shift (int): 1 for a later date, -1 for an earlier one
            time_delta (np.timedelta64): delta specifying the time resolution
                of the model data.

        Returns:
            The input date, unchanged, if it is a multiple of ``time_delta``
            (meaning it lies on the model time axis). Otherwise, the input
            date moved by ``shift * time_delta``.

        Example:
            If we have data on [t0, t1, dt], and we want an interpolation over
            [T0, T1], then we must make sure that t0 <= T0 - dt and t1 >= T1 +
            dt. If this condition is satisfied, interpolation at T0 and T1 will
            be possible. If this condition is not satisfied, interpolation
            becomes extrapolation.
        """
        # Adding then subtracting makes numpy promote both operands to a
        # common unit, so that their integer representations are comparable.
        aligned_date = date + time_delta - time_delta
        aligned_delta = date + time_delta - date

        remainder = aligned_date.astype("int64") % aligned_delta.astype(
            "int64")  # type: ignore
        if remainder == 0:
            # The date already lies on the model time axis.
            return date
        return date + time_delta * shift

    @staticmethod
    def _calculate_time_delta(dates: xr.DataArray) -> np.timedelta64:
        """Compute the constant time step separating two consecutive grids.

        Args:
            dates (xr.DataArray): dates of the model data.

        Returns:
            np.timedelta64: the time delta between two consecutive maps,
            expressed in nanoseconds.

        Raises:
            RuntimeError: if fewer than two dates are provided, or if the
                step between consecutive dates is not constant.
        """
        frequency = np.diff(dates)

        # No difference can be computed when fewer than two dates are known.
        try:
            reference = frequency[0]
        except IndexError as exc:
            raise RuntimeError(
                "Check that your list of data is not empty") from exc

        if not np.all(frequency == reference):
            raise RuntimeError(
                "Time series does not have a constant step between two "
                f"grids: {set(frequency)} seconds")
        return np.timedelta64(reference, "ns")
@nb.njit("(float32[::1])(int64[::1], float32[:, ::1], int64[::1])",
         cache=True,
         nogil=True)  # type:ignore
def time_interp(xp: np.ndarray, yp: np.ndarray, xi: np.ndarray) -> np.ndarray:
    """Linear time interpolation between spatially interpolated grids.

    Each column of ``yp`` holds the values obtained, for one measurement
    point, on every model grid. For each point, the two grids bracketing its
    date are blended with linear weights.

    Args:
        xp (numpy.ndarray): Dates of the model grids, as integers with a
            constant, strictly positive step. Shape ``(n_grids,)``.
        yp (numpy.ndarray): Values interpolated in space on every grid.
            Shape ``(n_grids, n_points)``.
        xi (numpy.ndarray): Date of each measurement point, in the same
            integer unit as ``xp``. Shape ``(n_points,)``.

    Returns:
        numpy.ndarray: The values interpolated in time, shape
        ``(n_points,)``.

    Raises:
        AssertionError: if the shapes are inconsistent, if fewer than two
            grids are provided, if the time step is not constant and
            positive, or if a date of ``xi`` falls outside the grids.
    """
    assert xp.shape[0] == yp.shape[0] and yp.shape[1] == xi.shape[0]
    # Two grids at least are needed to define a time step; checking this
    # first also protects the ``xp_diff[0]`` access below.
    assert xp.shape[0] >= 2

    xp_diff = np.diff(xp)
    step = xp_diff[0]
    assert step > 0 and np.all(xp_diff == step)

    size = xp.size
    result = np.empty(yp.shape[1:], dtype=yp.dtype)

    # One output value per measurement point: the weights depend on the date
    # of that point, so they must be computed inside the loop.
    for ix in range(xi.shape[0]):
        # Integer floor division selects the grid located just before (or
        # at) the requested date, which guarantees 0 <= w1 < 1 below.
        index = (xi[ix] - xp[0]) // step
        assert 0 <= index < size
        if index == size - 1:
            # A date on the last grid is handled with the last interval.
            i0 = index - 1
        else:
            i0 = index
        i1 = i0 + 1

        w1 = (xi[ix] - xp[i0]) / step
        w0 = 1.0 - w1
        result[ix] = w0 * yp[i0, ix] + w1 * yp[i1, ix]
    return result
class NetcdfLoader(DatasetLoader):
    """Plugin that implements a netcdf reader.

    The netcdf reader works on files whose names contain their date. A regular
    expression holding a named group ``date`` (ex. ``(?P<date>.*).nc``),
    associated with a date formatter (ex. ``%Y%m%d``), is used to build the
    time series.

    Netcdf files can be expensive to concatenate if there are a lot of
    files. This loader avoids loading too many files by building a table
    matching file paths to their time. During the interpolation, where only
    a given time period is needed, only the files that cover the time period
    are loaded in the dataset.
    """

    def __init__(self,
                 path: str,
                 date_fmt: str,
                 lon_name: str = "lon",
                 lat_name: str = "lat",
                 ssh_name: str = "ssh",
                 time_name: str = "time",
                 pattern: str = ".nc"):
        """Initialization of the netcdf loader.

        Args:
            path (str): Folder containing the netcdf files
            date_fmt (str): date formatter
            lon_name (str): longitude name in the netcdf files. Defaults to
                'lon'
            lat_name (str): latitude name in the netcdf files. Defaults to 'lat'
            ssh_name (str): sea surface height name in the netcdf files.
                Defaults to 'ssh'
            time_name (str): time name in the netcdf files. Defaults to 'time'
            pattern (str): Pattern for the NetCDF file names. It must contain
                the named group ``(?P<date>...)`` used to retrieve the time
                of each file. The default value does not define this group,
                so a pattern has to be provided as soon as a file matches.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
            RuntimeError: if fewer than two files are found, or if the files
                are not evenly spaced in time.

        Example:
            If we have netcdf files whose names are ``model_20120305_12h.nc``,
            we must define the following to retrieve the time:

            .. code-block:: python

                loader = NetcdfLoader(
                    '.',
                    pattern=r'model_(?P<date>\\w+)\\.nc',
                    date_fmt='%Y%m%d_%Hh'
                )
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"{path!r}")

        self.lon_name = lon_name
        self.lat_name = lat_name
        self.ssh_name = ssh_name
        self.time_name = time_name
        self.regex = re.compile(pattern).search
        self.date_fmt = date_fmt
        self.time_series = self._walk_netcdf(path)
        self.time_delta = self._calculate_time_delta(self.time_series["date"])

    def _walk_netcdf(self, path: str) -> np.ndarray:
        """Browse the NetCDF grids in the directory to create the time series
        constituted by these files (a file contains a time step).

        Args:
            path (str): Folder containing the netcdf files

        Returns:
            numpy.ndarray: Structured array with the fields ``date`` and
            ``path``, sorted by date. The array is empty when no file name
            matches the pattern.
        """
        # Walks a netcdf folder and finds data files in it
        items = []
        for dir_path, _, filenames in os.walk(path):
            for filename in filenames:
                match = self.regex(filename)
                if not match:
                    continue
                time_counter = np.datetime64(
                    datetime.datetime.strptime(match.group("date"),
                                               self.date_fmt))
                items.append((time_counter, os.path.join(dir_path, filename)))

        # Width of the unicode field. It defaults to 1 when no file matched,
        # so that an empty (but valid) array is built and the caller can
        # report the empty time series with an explicit message.
        length = max((len(filepath) for _, filepath in items), default=1)

        # The time series is encoded in a structured Numpy array containing
        # the date and path to the file.
        time_series = np.array(
            items,
            dtype={
                "names": ("date", "path"),
                "formats": ("datetime64[s]", f"U{length}"),
            },
        )
        return time_series[np.argsort(time_series["date"])]

    def select_netcdf_files(self, first_date: np.datetime64,
                            last_date: np.datetime64) -> np.ndarray:
        """Selects the netcdf files that cover the time period.

        Args:
            first_date (numpy.datetime64): first date of the time period
            last_date (numpy.datetime64): last date of the time period

        Returns:
            numpy.ndarray: Boolean mask, aligned with ``self.time_series``,
            flagging the netcdf files that cover the time period.

        Raises:
            IndexError: if the period, once extended to the model time axis,
                is not covered by the time series.
        """
        first_date = self._shift_date(first_date.astype("datetime64[ns]"), -1,
                                      self.time_delta)
        last_date = self._shift_date(last_date.astype("datetime64[ns]"), 1,
                                     self.time_delta)
        series_dates = self.time_series["date"]
        if first_date < series_dates[0] or last_date > series_dates[-1]:
            raise IndexError(
                f"period [{first_date}, {last_date}] is out of range: "
                f"[{self.time_series['date'][0]}, "
                f"{self.time_series['date'][-1]}]")
        # Both bounds are inclusive: when ``last_date`` lies on the model
        # time axis, the grid at that date is required to interpolate it.
        selected = np.logical_and(series_dates >= first_date,
                                  series_dates <= last_date)
        return selected

    def load_dataset(self, first_date: np.datetime64,
                     last_date: np.datetime64) -> xr.Dataset:
        """Loads the dataset between the given dates.

        Args:
            first_date (numpy.datetime64): first date to load.
            last_date (numpy.datetime64): last date to load.

        Returns:
            xarray.Dataset: the dataset loaded, with its longitude, latitude,
            sea surface height and time renamed to ``lon``, ``lat``, ``ssh``
            and ``time``.

        Raises:
            IndexError: if the requested period is not covered by the files.
        """
        LOGGER.debug("fetch %s for %s, %s", self.__class__.__name__,
                     first_date, last_date)
        selected = self.select_netcdf_files(first_date, last_date)

        # A plain list of ``str`` is handed over to xarray.
        paths = self.time_series["path"][selected].tolist()
        dataset = xr.open_mfdataset(paths,
                                    concat_dim=self.time_name,
                                    combine="nested")

        if self.time_name not in dataset.coords:
            LOGGER.debug(
                "Time coordinate %s was not found, assigning "
                "axis with time from file names", self.time_name)
            # The field of the time series is named "date".
            dataset = dataset.assign_coords(
                {self.time_name: self.time_series["date"][selected]})

        return dataset.rename({
            self.lon_name: "lon",
            self.lat_name: "lat",
            self.ssh_name: "ssh",
            self.time_name: "time"
        })
class IrregularGridHandler(Interface):
    """Default interpolator for an irregular grid.

    First, uses an RTree to do the spatial interpolation of every model grid,
    then does the time interpolation with a simple weighting of two grids.

    Args:
        dataset_loader (DatasetLoader): Data loader
    """

    def __init__(self, dataset_loader: DatasetLoader):
        self.dataset_loader = dataset_loader

    def interpolate(self, lon: np.ndarray, lat: np.ndarray,
                    dates: np.ndarray) -> np.ndarray:
        """Interpolate the SSH for the given coordinates.

        Args:
            lon (numpy.ndarray): longitude coordinates
            lat (numpy.ndarray): latitude coordinates
            dates (numpy.ndarray): dates

        Returns:
            numpy.ndarray: interpolated SSH
        """
        # Spatial interpolation of the SSH on the different selected grids.
        dataset = self.dataset_loader.load_dataset(
            dates.min(),  # type: ignore
            dates.max())  # type: ignore

        dates_p = dataset.time.load().data
        # The model coordinates are shared by all the time steps: they are
        # loaded in memory once, instead of once per grid.
        lon_p = np.asarray(dataset.lon.data)
        lat_p = np.asarray(dataset.lat.data)
        ssh_p = dataset.ssh.data
        assert ssh_p.shape[0] == len(dates_p)

        start_time = time.time()
        layers = np.stack([
            self._spatial_interp(ssh_p[index, :], lon_p, lat_p, lon, lat)
            for index in range(len(ssh_p))
        ])
        LOGGER.debug("interpolation completed in %.2fs for period %s, %s",
                     time.time() - start_time, dates.min(), dates.max())

        # Time interpolation of the SSH. Dates are handed over as integer
        # microseconds.
        return time_interp(
            dates_p.astype("datetime64[us]").astype("int64"),
            layers,
            dates.astype("datetime64[us]").astype("int64"),
        )

    @staticmethod
    def _spatial_interp(z_model: da.Array, x_model: da.Array,
                        y_model: da.Array, x_sat: np.ndarray,
                        y_sat: np.ndarray) -> np.ndarray:
        """Spatial interpolation of the SSH on the selected maps.

        The model arrays may be backed by dask or by numpy; they are
        converted with :py:func:`numpy.asarray` and flattened, so that each
        model point provides one (longitude, latitude, SSH) sample.

        Args:
            z_model (dask.array.Array or numpy.ndarray): model SSH
            x_model (dask.array.Array or numpy.ndarray): model longitude
            y_model (dask.array.Array or numpy.ndarray): model latitude
            x_sat (numpy.ndarray): satellite longitude
            y_sat (numpy.ndarray): satellite latitude

        Returns:
            numpy.ndarray: interpolated SSH in space, as float32.
        """
        model_points = np.vstack((np.asarray(x_model).ravel(),
                                  np.asarray(y_model).ravel())).T

        mesh = pyinterp.RTree()
        mesh.packing(model_points, np.asarray(z_model).ravel())

        z, _ = mesh.radial_basis_function(
            np.vstack((x_sat, y_sat)).T.astype("float32"),
            within=True,
            k=11,
            rbf="thin_plate",
            num_threads=1,
        )
        # float32 is the value type expected by the time interpolation.
        return z.astype("float32")
class CartesianGridHandler(Interface):
    """Default interpolator for regular grid. Uses
    pyinterp.backends.xarray.Grid3D.trivariate interpolator.

    Args:
        dataset_loader (DatasetLoader): DatasetLoader object
    """

    def __init__(self, dataset_loader: DatasetLoader):
        self.dataset_loader = dataset_loader

    def interpolate(self, lon: np.ndarray, lat: np.ndarray,
                    dates: np.ndarray) -> np.ndarray:
        """Interpolate the SSH to the required coordinates.

        Args:
            lon (numpy.ndarray): longitude coordinates
            lat (numpy.ndarray): latitude coordinates
            dates (numpy.ndarray): dates of the simulated measurements

        Returns:
            numpy.ndarray: interpolated SSH
        """
        dataset = self.dataset_loader.load_dataset(
            dates.min(),  # type: ignore
            dates.max())  # type: ignore

        ssh_grid = dataset.ssh
        interpolator = pyinterp.backends.xarray.Grid3D(ssh_grid)

        # NOTE(review): pyinterp is expected to match the keys of the
        # coordinates with the dimension names of the grid - confirm against
        # the pyinterp documentation. Data loaders expose the canonical names
        # "lon" and "lat"; the historical "longitude" and "latitude" keys
        # remain the fallback for grids that do not use the canonical names.
        lon_key = "lon" if "lon" in ssh_grid.dims else "longitude"
        lat_key = "lat" if "lat" in ssh_grid.dims else "latitude"

        ssh = interpolator.trivariate({
            lon_key: lon,
            lat_key: lat,
            "time": dates
        },
                                      bounds_error=True,
                                      interpolator="bilinear")
        return ssh