from __future__ import annotations
from typing import Any
import numpy as np
from sasktran2.atmosphere import Atmosphere
from sasktran2.optical.base import OpticalProperty
from sasktran2.util.interpolation import linear_interpolating_matrix
from .base import Constituent
[docs]
class GaussianHeightExtinction(Constituent):
    """Constituent with a Gaussian-in-altitude profile scaled to a vertical optical depth."""

    def __init__(
        self,
        optical_property: OpticalProperty,
        height_m: float,
        width_fwhm_m: float,
        vertical_optical_depth: float,
        vertical_optical_depth_wavel_nm: float,
        altitudes_m: np.ndarray,
        out_of_bounds_mode: str = "zero",
        **kwargs,
    ) -> None:
        """
        A constituent that is defined by a gaussian-shaped extinction profile.

        The number density on ``altitudes_m`` is a gaussian in altitude, normalized so that the
        trapezoidal-rule vertical integral of number density times the cross section at
        ``vertical_optical_depth_wavel_nm`` equals ``vertical_optical_depth``.

        Parameters
        ----------
        optical_property : OpticalProperty
            The optical property defining the scattering information
        height_m : float
            Height of the centre of the gaussian extinction profile in [m]
        width_fwhm_m : float
            FWHM of the gaussian extinction profile in [m]
        vertical_optical_depth : float
            Vertical optical depth at ``vertical_optical_depth_wavel_nm``
        vertical_optical_depth_wavel_nm : float
            Wavelength in [nm] that the vertical optical depth is specified at
        altitudes_m : np.ndarray
            The altitude grid in [m] over which the optical depth is calculated, as well as the
            grid for any optical property arguments passed in through kwargs.
        out_of_bounds_mode : str, optional
            Interpolation mode for outside of the boundaries of the altitude grid, "extend" and
            "zero" are supported, by default "zero"
        kwargs : dict
            Additional arguments to pass to the optical property, defined on ``altitudes_m``.
            Each one can also be read and assigned as an attribute of this object.
        """
        super().__init__()
        self._out_of_bounds_mode = out_of_bounds_mode
        # Normalize the grid to a float array so the profile arithmetic, interpolation and
        # trapezoidal integration below always operate on ndarrays.
        self._altitudes_m = np.asarray(altitudes_m, dtype=float)
        self._optical_property = optical_property
        # Save inputs as array datatype so that they are mutable
        self._height_m = np.array(height_m, dtype=float)
        self._width_fwhm_m = np.array(width_fwhm_m, dtype=float)
        self._vertical_optical_depth = np.array(vertical_optical_depth, dtype=float)
        self._vertical_optical_depth_wavel_nm = vertical_optical_depth_wavel_nm
        self._kwargs = kwargs

    def __getattr__(self, name: str) -> Any:
        """Expose optical-property kwargs as read-only-by-lookup attributes.

        Only invoked when normal attribute lookup fails. Unknown names raise
        ``AttributeError`` as the data model requires; returning ``None`` instead would make
        ``hasattr`` always true and break ``copy``/``pickle``, which probe for
        ``__setstate__`` and would then try to call ``None``.
        """
        # Read through __dict__ so a partially constructed instance (e.g. during unpickling,
        # before ``_kwargs`` exists) cannot recurse back into __getattr__.
        kwargs = self.__dict__.get("_kwargs", {})
        if name in kwargs:
            return kwargs[name]
        msg = f"'{type(self).__name__}' object has no attribute '{name}'"
        raise AttributeError(msg)

    def __setattr__(self, name: str, value: Any) -> None:
        """Route assignments to existing optical-property kwargs into ``_kwargs``."""
        if name in self.__dict__.get("_kwargs", {}):
            self._kwargs[name] = value
        else:
            super().__setattr__(name, value)

    @property
    def height_m(self) -> np.ndarray:
        """Height of the centre of the gaussian profile in [m]."""
        return self._height_m

    @height_m.setter
    def height_m(self, height_m: np.ndarray) -> None:
        # Keep the float-array invariant established in __init__
        self._height_m = np.asarray(height_m, dtype=float)

    @property
    def width_fwhm_m(self) -> np.ndarray:
        """FWHM of the gaussian profile in [m]."""
        return self._width_fwhm_m

    @width_fwhm_m.setter
    def width_fwhm_m(self, width_fwhm_m: np.ndarray) -> None:
        self._width_fwhm_m = np.asarray(width_fwhm_m, dtype=float)

    @property
    def vertical_optical_depth(self) -> np.ndarray:
        """Vertical optical depth at the reference wavelength."""
        return self._vertical_optical_depth

    @vertical_optical_depth.setter
    def vertical_optical_depth(self, vertical_optical_depth: np.ndarray) -> None:
        self._vertical_optical_depth = np.asarray(vertical_optical_depth, dtype=float)

    def _interpolation(
        self, atmo: Atmosphere
    ) -> tuple[np.ndarray, dict[str, np.ndarray]]:
        """Build the native-to-model altitude interpolation matrix and interpolate the kwargs."""
        interp_matrix = linear_interpolating_matrix(
            self._altitudes_m,
            atmo.model_geometry.altitudes(),
            self._out_of_bounds_mode,
        )
        interped_kwargs = {k: interp_matrix @ v for k, v in self._kwargs.items()}
        return interp_matrix, interped_kwargs

    def add_to_atmosphere(self, atmo: Atmosphere):
        """
        Add this constituent's optical quantities to ``atmo.storage``.

        Computes the gaussian number density profile on the native grid, interpolates it to
        the model altitudes, and accumulates extinction, the ``ssa`` container and the
        Legendre coefficients (weighted by the ``ssa`` container) into the atmosphere storage.
        The profile quantities are cached on ``self`` for use by ``register_derivative``.

        Parameters
        ----------
        atmo : Atmosphere
            The atmosphere to add this constituent to
        """
        interp_matrix, interped_kwargs = self._interpolation(atmo)

        # Cross section at the reference wavelength on the native altitude grid
        self._xs_at_wavel = self._optical_property.cross_sections(
            np.array([self._vertical_optical_depth_wavel_nm]),
            altitudes_m=self._altitudes_m,
            **self._kwargs,
        ).extinction.flatten()

        # Unnormalized gaussian since we will normalize to vertical optical depth anyways
        self._gaussian = np.exp(
            -4
            * np.log(2)
            * (self._altitudes_m - self._height_m) ** 2
            / self._width_fwhm_m**2
        )

        self._optical_quants = self._optical_property.atmosphere_quantities(
            atmo, **interped_kwargs
        )

        # Scale so that trapz(number_density * xs_at_wavel, altitudes) == vertical_optical_depth
        self._gaussian_od = np.trapezoid(self._gaussian, self._altitudes_m)
        self._number_density = (
            self._gaussian
            * self._vertical_optical_depth
            / self._gaussian_od
            / self._xs_at_wavel
        )
        self._interp_numden = interp_matrix @ self._number_density
        numden = self._interp_numden[:, np.newaxis]

        atmo.storage.total_extinction[:] += self._optical_quants.extinction * numden
        atmo.storage.ssa[:] += self._optical_quants.ssa * numden
        atmo.storage.leg_coeff[:] += (
            self._optical_quants.ssa[np.newaxis, :, :]
            * self._interp_numden[np.newaxis, :, np.newaxis]
            * self._optical_quants.leg_coeff
        )

        # Convert back to SSA for ease of use later in the derivatives. Zero extinction
        # yields non-finite values which are replaced below, so silence the expected warnings.
        with np.errstate(divide="ignore", invalid="ignore"):
            self._optical_quants.ssa[:] /= self._optical_quants.extinction
        self._optical_quants.ssa[~np.isfinite(self._optical_quants.ssa)] = 1

    def _d_number_density(
        self, d_gaussian: np.ndarray, outer_term: np.ndarray
    ) -> np.ndarray:
        """Derivative of the normalized number density given d(gaussian)/d(parameter).

        The number density is ``outer_term * gaussian`` with ``outer_term`` containing
        ``1 / trapz(gaussian)``, so the quotient rule adds a normalization correction.
        """
        normalization = (
            self._gaussian
            * np.trapezoid(d_gaussian, self._altitudes_m)
            / self._gaussian_od
        )
        return outer_term * (d_gaussian - normalization)

    @staticmethod
    def _add_profile_mapping(
        atmo: Atmosphere,
        interp_matrix: np.ndarray,
        name: str,
        parameter: str,
        scattering_terms: tuple,
        d_number_density: np.ndarray,
    ) -> None:
        """Accumulate the derivative mapping for one gaussian-profile parameter."""
        d_extinction, d_ssa, d_leg_coeff, scat_factor = scattering_terms
        mapping = atmo.storage.get_derivative_mapping(f"wf_{name}_{parameter}")
        mapping.d_extinction[:] += d_extinction
        mapping.d_ssa[:] += d_ssa
        mapping.d_leg_coeff[:] += d_leg_coeff
        mapping.scat_factor[:] += scat_factor
        mapping.interp_dim = f"{name}_{parameter}"
        mapping.interpolator = interp_matrix @ d_number_density[:, np.newaxis]

    def _add_optical_property_mappings(
        self,
        atmo: Atmosphere,
        interp_matrix: np.ndarray,
        interped_kwargs: dict,
        name: str,
    ) -> None:
        """Add derivative mappings for each optical-property argument (kwarg)."""
        optical_derivs = self._optical_property.optical_derivatives(
            atmo, **interped_kwargs
        )

        vertical_deriv_factor = 1 / self._xs_at_wavel
        d_xs = self._optical_property.cross_section_derivatives(
            np.array([self._vertical_optical_depth_wavel_nm]),
            altitudes_m=self._altitudes_m,
            **self._kwargs,
        )
        # Convert from derivative of x to derivative of 1/x: d(1/x) = -(1/x)**2 dx.
        # Build new arrays rather than mutating the ones owned by the optical property.
        d_vertical_deriv_factor = {
            key: val * (-1 * vertical_deriv_factor**2) for key, val in d_xs.items()
        }

        for key, val in optical_derivs.items():
            # Code copied from numdenscatterer.py
            deriv_mapping = atmo.storage.get_derivative_mapping(f"wf_{name}_{key}")

            deriv_mapping.d_extinction[:] += val.d_extinction
            d_extinction_scat = val.d_ssa

            # First, the optical property returns back d_scattering extinction in the d_ssa container,
            # convert this to d_ssa
            deriv_mapping.d_ssa[:] += (
                d_extinction_scat - val.d_extinction * self._optical_quants.ssa
            ) / self._optical_quants.extinction
            deriv_mapping.d_leg_coeff[:] += val.d_leg_coeff

            if key in d_vertical_deriv_factor:
                # The number density scales with 1/xs_at_wavel, so changing this argument
                # also changes the extinction through that normalization.
                deriv_mapping.d_extinction[:] += (
                    self._optical_quants.extinction
                    / (interp_matrix @ vertical_deriv_factor)[:, np.newaxis]
                    * (interp_matrix @ d_vertical_deriv_factor[key])[:, np.newaxis]
                )
                # Change in single scatter albedo should be invariant whether or not we are
                # in extinction space or number density space

            # Start with leg_coeff
            deriv_mapping.d_leg_coeff[:] += (
                self._optical_quants.leg_coeff - atmo.storage.leg_coeff
            ) * (
                1 / self._optical_quants.ssa * deriv_mapping.d_ssa
                + 1 / self._optical_quants.extinction * deriv_mapping.d_extinction
            )[np.newaxis, :, :]

            # Then adjust d_ssa
            deriv_mapping.d_ssa[:] *= self._optical_quants.extinction
            deriv_mapping.d_ssa[:] += deriv_mapping.d_extinction * (
                self._optical_quants.ssa - atmo.storage.ssa
            )
            deriv_mapping.d_ssa[:] /= atmo.storage.total_extinction

            # TODO: The model should probably handle this
            norm_factor = deriv_mapping.d_leg_coeff.max(axis=0)
            norm_factor[norm_factor == 0] = 1

            deriv_mapping.scat_factor[:] = (
                self._optical_quants.ssa * self._optical_quants.extinction
            ) / (atmo.storage.ssa * atmo.storage.total_extinction)

            deriv_mapping.d_leg_coeff[:] /= norm_factor[np.newaxis, :, :]
            deriv_mapping.scat_factor[:] *= norm_factor

            deriv_mapping.interpolator = (
                interp_matrix * self._number_density[np.newaxis, :]
            )
            deriv_mapping.interp_dim = f"{name}_altitude"

    def register_derivative(self, atmo: Atmosphere, name: str):
        """
        Register weighting-function derivative mappings for this constituent.

        Must be called after ``add_to_atmosphere``, whose cached profile quantities
        (``_gaussian``, ``_gaussian_od``, ``_xs_at_wavel``, ``_optical_quants``,
        ``_number_density``) it reuses. Mappings are accumulated for ``height_m``,
        ``width_fwhm_m``, ``vertical_optical_depth`` and every derivative reported by the
        optical property's ``optical_derivatives``.

        Parameters
        ----------
        atmo : Atmosphere
            The atmosphere whose derivative storage is filled
        name : str
            Name of this constituent; mappings are named ``wf_{name}_<parameter>``

        Returns
        -------
        dict
            Always an empty dictionary; all results are written to ``atmo.storage``.
        """
        interp_matrix, interped_kwargs = self._interpolation(atmo)

        # d(gaussian)/d(height) and d(gaussian)/d(fwhm) for
        # gaussian = exp(-4 ln2 (z - h)**2 / w**2)
        offset = self._altitudes_m - self._height_m
        d_gaussian_d_height = (
            self._gaussian * 8 * np.log(2) * offset / self._width_fwhm_m**2
        )
        d_gaussian_d_width = (
            self._gaussian * 8 * np.log(2) * offset**2 / self._width_fwhm_m**3
        )
        outer_term = (
            self._vertical_optical_depth / self._gaussian_od / self._xs_at_wavel
        )

        # Terms shared by the three profile parameters: they only rescale number density
        d_extinction = self._optical_quants.extinction
        d_ssa = (
            self._optical_quants.extinction
            * (self._optical_quants.ssa - atmo.storage.ssa)
            / atmo.storage.total_extinction
        )
        d_leg_coeff = self._optical_quants.leg_coeff - atmo.storage.leg_coeff
        scat_factor = (self._optical_quants.ssa * self._optical_quants.extinction) / (
            atmo.storage.ssa * atmo.storage.total_extinction
        )
        scattering_terms = (d_extinction, d_ssa, d_leg_coeff, scat_factor)

        self._add_profile_mapping(
            atmo,
            interp_matrix,
            name,
            "height_m",
            scattering_terms,
            self._d_number_density(d_gaussian_d_height, outer_term),
        )
        self._add_profile_mapping(
            atmo,
            interp_matrix,
            name,
            "width_fwhm_m",
            scattering_terms,
            self._d_number_density(d_gaussian_d_width, outer_term),
        )
        # Number density is linear in the vertical optical depth
        self._add_profile_mapping(
            atmo,
            interp_matrix,
            name,
            "vertical_optical_depth",
            scattering_terms,
            self._gaussian / self._gaussian_od / self._xs_at_wavel,
        )

        self._add_optical_property_mappings(atmo, interp_matrix, interped_kwargs, name)

        return {}