Source code for swot_simulator.plugins.ssh.mitgcm

# Copyright (c) 2020 CNES/JPL
#
# All rights reserved. Use of this source code is governed by a
# BSD-style license that can be found in the LICENSE file.
"""
Interpolate SSH from MIT/GCM model
==================================
"""
import logging
import time
import dask.array as da
import numba as nb
import numpy as np
import pyinterp
import xarray as xr
from . import detail

LOGGER = logging.getLogger(__name__)


@nb.njit('(float32[::1])(int64[::1], float32[:, ::1], int64[::1])',
         cache=True,
         nogil=True)
def _time_interp(xp: np.ndarray, yp: np.ndarray, xi: np.ndarray) -> np.ndarray:
    """Time interpolation for the different spatial grids interpolated on the
    SWOT data.

    Each measurement is linearly interpolated between the two model grids
    that bracket its own date.

    Args:
        xp: Dates of the model grids (int64), in increasing order with a
            constant step.
        yp: Grids interpolated on the measurements: one row per date of
            ``xp``, one column per measurement.
        xi: Date of each measurement (int64, same unit as ``xp``).

    Returns:
        The value interpolated in time for each measurement.
    """
    assert xp.shape[0] >= 1
    assert xp.shape[0] == yp.shape[0] and yp.shape[1] == xi.shape[0]

    size = xp.shape[0]
    result = np.empty(yp.shape[1:], dtype=yp.dtype)

    # A single grid: every date must match it exactly and there is nothing
    # to interpolate (np.diff would be empty and xp_diff[0] out of bounds).
    if size == 1:
        for ix in range(yp.shape[1]):
            assert xi[ix] == xp[0]
            result[ix] = yp[0, ix]
        return result

    xp_diff = np.diff(xp)
    delta = xp_diff[0]
    assert delta > 0 and np.all(xp_diff == delta)

    # One interpolation per measurement (column of yp), weighted by the date
    # of that measurement.
    for ix in range(yp.shape[1]):
        date = xi[ix]
        assert date >= xp[0] and date <= xp[size - 1]
        # Grid at or immediately before the date (integer floor division);
        # a date on the last grid is paired with the previous grid.
        i0 = min((date - xp[0]) // delta, size - 2)
        i1 = i0 + 1
        w1 = (date - xp[i0]) / delta
        result[ix] = (1.0 - w1) * yp[i0, ix] + w1 * yp[i1, ix]

    return result


def _spatial_interp(z_model: da.array, x_model: da.array, y_model: da.array,
                    x_sat: np.ndarray, y_sat: np.ndarray):
    """Spatial interpolation of one model grid on the satellite positions.

    The model arrays are indexed by face along their first axis. Only the
    faces whose longitude/latitude bounding box covers at least one
    satellite position are loaded. Their defined (non-NaN) values are packed
    into an R*Tree and interpolated with a thin-plate radial basis function
    (k=11, radius=55000).

    Args:
        z_model: SSH of one model grid.
        x_model: Longitudes of the model grid, same layout as ``z_model``.
        y_model: Latitudes of the model grid, same layout as ``z_model``.
        x_sat: Longitudes of the satellite positions.
        y_sat: Latitudes of the satellite positions.

    Returns:
        The interpolated values (float32), one per satellite position; NaN
        where no model value can be used.
    """
    # Without any satellite position there is nothing to load or interpolate.
    if x_sat.size == 0:
        return np.empty(x_sat.shape, dtype="float32")

    x, y, z = [], [], []

    start_time = time.time()

    for face in range(x_model.shape[0]):
        x_face = x_model[face, :].compute()
        y_face = y_model[face, :].compute()

        # We test if the face covers the satellite positions.
        box = pyinterp.geodetic.Box2D(
            pyinterp.geodetic.Point2D(x_face.min(), y_face.min()),
            pyinterp.geodetic.Point2D(x_face.max(), y_face.max()))
        if not np.any(box.covered_by(x_sat, y_sat) == 1):
            continue

        # The undefined values are filtered. Boolean indexing already returns
        # a new 1-D array, so no extra flatten copy is needed.
        z_face = z_model[face, :].compute()
        defined = ~np.isnan(z_face)
        x.append(x_face[defined])
        y.append(y_face[defined])
        z.append(z_face[defined])

    # No face covers the satellite positions: np.concatenate would fail on an
    # empty sequence, and there is no model value to interpolate from.
    if not z:
        return np.full(x_sat.shape, np.nan, dtype="float32")

    # The tree is built and the interpolation is calculated
    coordinates = np.vstack((np.concatenate(x), np.concatenate(y))).T
    del x, y

    z = np.concatenate(z)
    LOGGER.debug("loaded %d MB in %.2fs",
                 (coordinates.nbytes + z.nbytes) // 1024**2,
                 time.time() - start_time)
    start_time = time.time()
    mesh = pyinterp.RTree(dtype="float32")
    mesh.packing(coordinates, z)
    LOGGER.debug("mesh build in %.2fs", time.time() - start_time)

    del coordinates, z

    start_time = time.time()
    z, _ = mesh.radial_basis_function(
        np.vstack((x_sat, y_sat)).T.astype("float32"),
        within=True,
        k=11,
        radius=55000,
        rbf="thin_plate",
        num_threads=1)
    LOGGER.debug("interpolation done in %.2fs", time.time() - start_time)
    del mesh
    return z.astype("float32")


class MITGCM(detail.Interface):
    """Interpolation of the SSH provided by the MIT/GCM model.

    Args:
        xc: Longitudes of the model grid.
        yc: Latitudes of the model grid.
        eta: SSH of the model, with a ``time`` coordinate whose step between
            two consecutive grids must be constant.
    """

    def __init__(self, xc: xr.DataArray, yc: xr.DataArray,
                 eta: xr.DataArray):
        self.lon = xc.data
        self.lat = yc.data
        self.ssh = eta.data
        # Dates are handled in microseconds throughout this class so that
        # they can be compared and converted to int64 consistently.
        self.ts = eta.time.data.astype("datetime64[us]")
        self.dt = self._calculate_dt(self.ts)

    @staticmethod
    def _calculate_dt(dates: np.ndarray) -> np.timedelta64:
        """Calculation of the delta T between two consecutive grids.

        Raises:
            RuntimeError: if fewer than two grids are provided or if the step
                between two consecutive grids is not constant.
        """
        if len(dates) < 2:
            raise RuntimeError(
                "Time series must contain at least two grids to interpolate "
                "in time")
        frequency = np.diff(dates)
        if not np.all(frequency == frequency[0]):
            # str() of a timedelta64 carries its own unit.
            steps = ", ".join(str(item) for item in np.unique(frequency))
            raise RuntimeError(
                "Time series does not have a constant step between two "
                f"grids: {steps}")
        return frequency[0]

    def _grid_date(self, date: np.datetime64, shift: int) -> np.datetime64:
        """Calculates the grid date immediately before (``shift < 0``) or
        after (``shift > 0``) the date provided.

        Grid dates are ``self.ts[0] + k * self.dt``. A date lying exactly on a
        grid (or a null ``shift``) is returned unchanged.
        """
        date = date.astype("datetime64[us]")
        # Offset from the first grid, reduced modulo the grid step. numpy's
        # timedelta remainder takes the sign of the divisor, so it lies in
        # [0, dt) even for dates before the first grid.
        remainder = (date - self.ts[0]) % self.dt
        if remainder == np.timedelta64(0):
            return date
        previous = date - remainder
        if shift < 0:
            return previous
        if shift > 0:
            return previous + self.dt
        return date

    def interpolate(self, lon: np.ndarray, lat: np.ndarray,
                    dates: np.ndarray) -> np.ndarray:
        """Interpolate the SSH for the given coordinates

        Args:
            lon: Longitudes of the measurements.
            lat: Latitudes of the measurements.
            dates: Dates of the measurements.

        Returns:
            The SSH interpolated in space and time (float32), one value per
            measurement.

        Raises:
            IndexError: if the dates are not covered by the model time series.
        """
        if dates.size == 0:
            return np.empty((0, ), dtype="float32")

        # min/max rather than first/last: the dates need not be sorted.
        first_date = self._grid_date(dates.min(), -1)
        last_date = self._grid_date(dates.max(), 1)
        if first_date < self.ts[0] or last_date > self.ts[-1]:
            raise IndexError(
                f"period [{first_date}, {last_date}] is out of range: "
                f"[{self.ts[0]}, {self.ts[-1]}]")

        # Mask for selecting data covering the time period provided.
        mask = (self.ts >= first_date) & (self.ts <= last_date)

        LOGGER.debug("fetch data for %s, %s", first_date, last_date)

        # Data cube holding the grids necessary for interpolation.
        frame = self.ssh[mask]

        # Spatial interpolation of the SSH on the different selected grids.
        start_time = time.time()
        layers = np.stack([
            _spatial_interp(frame[index, :], self.lon, self.lat, lon, lat)
            for index in range(frame.shape[0])
        ])

        # Time interpolation of the SSH.
        LOGGER.debug("interpolation completed in %.2fs for period %s, %s",
                     time.time() - start_time, first_date, last_date)
        return _time_interp(self.ts[mask].astype("int64"), layers,
                            dates.astype("datetime64[us]").astype("int64"))