from __future__ import annotations
from copy import deepcopy
from pathlib import Path
import numpy as np
import xarray as xr
import sasktran2 as sk
from sasktran2.atmosphere import Atmosphere, NativeGridDerivative
from sasktran2.optical.base import OpticalProperty, OpticalQuantities
from sasktran2.polarization import LegendreStorageView
[docs]
class OpticalDatabase(OpticalProperty):
[docs]
def __init__(self, db_filepath: Path) -> None:
"""
An optical property that is defined by a database file. This is just a base class to handle file loading,
derived classes must be used as actual optical properties.
Parameters
----------
db_filepath : Path
Path to the optical database file
Raises
------
OSError
If xarray is not installed
"""
super().__init__()
self._file = db_filepath
try:
import xarray as xr
except ImportError:
msg = "xarray must be installed to use OpticalDatabaseGenericAbsorber"
raise msg from OSError
self._database = xr.open_dataset(self._file)
self._validate_db()
[docs]
class OpticalDatabaseGenericAbsorber(OpticalDatabase):
[docs]
def __init__(self, db_filepath: Path) -> None:
"""
A purely absorbing optical property defined by a database file. The database must contain the following
- xs : The absorption cross section in [m^2]
xs must be a function of either wavelength_nm or wavenumber_cminv, and optionally temperature and pressure.
Parameters
----------
db_filepath : Path
Path to the database file
"""
super().__init__(db_filepath)
self._validate_db()
self._quants = None
self._cached_interp_handler = None
def _validate_db(self):
# Old file format used "temperature" and "pressure" instead of the standard keys
# "temperature_k" and "pressure_pa", if so we rename them
if "temperature" in self._database:
self._database = self._database.rename({"temperature": "temperature_k"})
if "pressure" in self._database:
self._database = self._database.rename({"pressure": "pressure_pa"})
if "xs" not in self._database:
msg = "xs must be defined in the optical database"
raise ValueError(msg)
def optical_derivatives(self, atmo: Atmosphere, **kwargs) -> dict:
derivs = {}
interp_handler = self._construct_interp_handler(atmo, **kwargs)
# Split the interpolators into ones that are 'z' dependent and ones that are not
interp_handler_z = {}
interp_handler_noz = {}
for key, val in interp_handler.items():
if val[0] == "z":
interp_handler_z[key] = val
else:
interp_handler_noz[key] = val
# Get the derivatives of the cross section with respect to the z dependent variables
partial_interp = self._database["xs"].interp(**interp_handler_noz)
for key, val in interp_handler_z.items():
# Interpolate over the other variables
new_interpolants = {k: v for k, v in interp_handler_z.items() if k != key}
partial_interp2 = partial_interp.interp(**new_interpolants)
dT = partial_interp2.diff(key) / partial_interp2[key].diff(key)
interp_index = np.argmax(dT[key].to_numpy() > val[1][:, np.newaxis], axis=1)
if "z" in dT.dims:
dT = dT.isel(
{
"z": xr.DataArray(list(range(len(interp_index))), dims="z"),
key: xr.DataArray(interp_index, dims="z"),
}
)
else:
dT = dT.isel({key: interp_index})
derivs[key] = NativeGridDerivative(d_extinction=dT.to_numpy())
derivs[key].d_extinction[np.isnan(derivs[key].d_extinction)] = 0
return derivs
def _construct_interp_handler(self, atmo: Atmosphere, **kwargs) -> dict:
coords = self._database["xs"].coords
interp_handler = {}
# TODO: this could probably be refactored to iterate over atmosphere properties?
if "temperature_k" in coords:
if atmo.temperature_k is None:
msg = "temperature_k must be specified in Atmosphere to use OpticalDatabaseGenericAbsorber"
raise ValueError(msg)
interp_handler["temperature_k"] = ("z", atmo.temperature_k)
if "pressure_pa" in coords:
if atmo.pressure_pa is None:
msg = "pressure_pa must be specified in Atmosphere to use OpticalDatabaseGenericAbsorber"
raise ValueError(msg)
interp_handler["pressure_pa"] = ("z", atmo.pressure_pa)
if "wavelength_nm" in coords:
if atmo.wavelengths_nm is None:
msg = "wavelengths_nm must be specified in Atmosphere to use OpticalDatabaseGenericAbsorber"
interp_handler["wavelength_nm"] = atmo.wavelengths_nm
if "wavenumber_cminv" in coords:
if atmo.wavenumbers_cminv is None:
msg = "wavenumber_cminv must be specified in Atmosphere to use OpticalDatabaseGenericAbsorber"
interp_handler["wavenumber_cminv"] = atmo.wavenumbers_cminv
return interp_handler
def atmosphere_quantities(self, atmo: sk.Atmosphere, **kwargs) -> OpticalQuantities:
interp_handler = self._construct_interp_handler(atmo, **kwargs)
try:
np.testing.assert_equal(self._cached_interp_handler, interp_handler)
return self._quants
except (ValueError, AssertionError):
self._quants = OpticalQuantities(ssa=np.zeros_like(atmo.storage.ssa))
self._quants.extinction = (
self._database["xs"].interp(**interp_handler).to_numpy()
)
# Out of bounds, set to 0
self._quants.extinction[np.isnan(self._quants.extinction)] = 0
self._cached_interp_handler = deepcopy(interp_handler)
return deepcopy(self._quants)
[docs]
class OpticalDatabaseGenericScatterer(OpticalDatabase):
[docs]
def __init__(self, db_filepath: Path) -> None:
"""
A purely scattering optical property defined by a database file. The database must contain the following
- xs_total : The total cross section in [m^2]
- xs_scattering : The scattering cross section in [m^2]
- lm_a1 : the legendre coefficients for the phase function
All variables must be a function of either wavelength_nm or wavenumber_cminv, and optionally any other dimension such
as particle size.
Parameters
----------
db_filepath : Path
Path to the database file
"""
super().__init__(db_filepath)
self._validate_db()
def _validate_db(self):
self._database["lm_a1"] /= self._database["lm_a1"].isel(legendre=0)
def _construct_interp_handler(self, atmo: Atmosphere, **kwargs) -> dict:
coords = self._database["xs_total"].coords
interp_handler = {}
if "wavelength_nm" in coords:
if atmo.wavelengths_nm is None:
msg = "wavelengths_nm must be specified in Atmosphere to use OpticalDatabaseGenericScatterer"
raise ValueError(msg)
interp_handler["wavelength_nm"] = atmo.wavelengths_nm
if "wavenumber_cminv" in coords:
if atmo.wavenumber_cminv is None:
msg = "wavenumber_cminv must be specified in Atmosphere to use OpticalDatabaseGenericScatterer"
raise ValueError(msg)
interp_handler["wavenumber_cminv"] = atmo.wavenumber_cminv
for name, vals in kwargs.items():
if name in coords:
interp_handler[name] = ("z", vals)
return interp_handler
def cross_sections(
self, wavelengths_nm: np.array, altitudes_m: np.array, **kwargs # noqa: ARG002
) -> OpticalQuantities:
quants = OpticalQuantities()
coords = self._database["xs_total"].coords
interp_handler = {}
interp_handler["wavelength_nm"] = wavelengths_nm
for name, vals in kwargs.items():
if name in coords:
interp_handler[name] = ("z", vals)
ds_interp = self._database.interp(**interp_handler)
quants.extinction = ds_interp["xs_total"].to_numpy()
quants.ssa = ds_interp["xs_scattering"].to_numpy() / quants.extinction
quants.extinction[np.isnan(quants.extinction)] = 0
quants.ssa[np.isnan(quants.ssa)] = 0
return quants
def atmosphere_quantities(self, atmo: Atmosphere, **kwargs) -> OpticalQuantities:
quants = OpticalQuantities()
interp_handler = self._construct_interp_handler(atmo, **kwargs)
num_assign_legendre = min(
atmo.leg_coeff.a1.shape[0], len(self._database["legendre"])
)
if atmo.nstokes == 1:
drop_vars = ["lm_a2", "lm_a3", "lm_a4", "lm_b1", "lm_b2"]
elif atmo.nstokes == 3:
drop_vars = ["lm_a4", "lm_b2"]
else:
drop_vars = []
ds_interp = (
self._database.isel(legendre=slice(0, num_assign_legendre))
.drop(drop_vars)
.interp(**interp_handler)
)
if "z" not in ds_interp.dims:
# Our dataset has no z dependence
ds_interp = ds_interp.expand_dims(
dim={"z": len(atmo.model_geometry.altitudes())}
)
quants.extinction = np.copy(
ds_interp["xs_total"].transpose("z", "wavelength_nm").to_numpy()
)
quants.ssa = np.copy(
ds_interp["xs_scattering"].transpose("z", "wavelength_nm").to_numpy()
)
quants.leg_coeff = np.zeros_like(atmo.storage.leg_coeff)
leg_coeff = LegendreStorageView(quants.leg_coeff, atmo.nstokes)
leg_coeff.a1[:] = (
(ds_interp["lm_a1"])
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
# Renormalize leg_coeffs so a1 is always 1
leg_coeff.a1[:] /= quants.leg_coeff[0, :, :][np.newaxis, :, :]
quants.extinction[np.isnan(quants.extinction)] = 0
quants.ssa[np.isnan(quants.ssa)] = 0
if atmo.nstokes == 3:
# TODO: add pol properties
leg_coeff.a2[:] = (
(ds_interp["lm_a2"])
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
leg_coeff.a3[:] = (
(ds_interp["lm_a3"])
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
leg_coeff.b1[:] = (
(ds_interp["lm_b1"])
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
quants.leg_coeff[np.isnan(quants.leg_coeff)] = 0
return quants
def optical_derivatives(self, atmo: Atmosphere, **kwargs) -> dict:
derivs = {}
interp_handler = self._construct_interp_handler(atmo, **kwargs)
# Split the interpolators into ones that are 'z' dependent and ones that are not
interp_handler_z = {}
interp_handler_noz = {}
for key, val in interp_handler.items():
if val[0] == "z":
interp_handler_z[key] = val
else:
interp_handler_noz[key] = val
num_assign_legendre = min(
atmo.leg_coeff.a1.shape[0], len(self._database["legendre"])
)
if atmo.nstokes == 1:
drop_vars = ["lm_a2", "lm_a3", "lm_a4", "lm_b1", "lm_b2"]
elif atmo.nstokes == 3:
drop_vars = ["lm_a4", "lm_b2"]
else:
drop_vars = []
# Get the derivatives of the cross section with respect to the z dependent variables
partial_interp = (
self._database.isel(legendre=slice(0, num_assign_legendre))
.drop(drop_vars)
.interp(**interp_handler_noz)
)
for key, val in interp_handler_z.items():
# If the db only contains one element in the dimension we can't take a derivative
if len(self._database[key]) == 1:
continue
# Interpolate over the other variables
new_interpolants = {k: v for k, v in interp_handler_z.items() if k != key}
partial_interp2 = partial_interp.interp(**new_interpolants)
dT = partial_interp2.diff(key) / partial_interp2[key].diff(key)
interp_index = np.argmax(dT[key].to_numpy() > val[1][:, np.newaxis], axis=1)
if "z" in dT.dims:
dT = dT.isel(
{
"z": xr.DataArray(list(range(len(interp_index))), dims="z"),
key: xr.DataArray(interp_index, dims="z"),
}
)
else:
dT = dT.isel({key: xr.DataArray(interp_index, dims="z")})
derivs[key] = NativeGridDerivative(
d_extinction=dT["xs_total"].transpose("z", "wavelength_nm").to_numpy(),
d_ssa=dT["xs_scattering"].transpose("z", "wavelength_nm").to_numpy(),
d_leg_coeff=np.zeros_like(atmo.storage.leg_coeff),
)
d_leg_coeff = LegendreStorageView(derivs[key].d_leg_coeff, atmo.nstokes)
d_leg_coeff.a1[:] = (
dT["lm_a1"]
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
if atmo.nstokes == 3:
d_leg_coeff.a2[:] = (
dT["lm_a2"]
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
d_leg_coeff.a3[:] = (
dT["lm_a3"]
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
d_leg_coeff.b1[:] = (
dT["lm_b1"]
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
derivs[key].d_extinction[np.isnan(derivs[key].d_extinction)] = 0
derivs[key].d_ssa[np.isnan(derivs[key].d_ssa)] = 0
derivs[key].d_leg_coeff[np.isnan(derivs[key].d_leg_coeff)] = 0
return derivs
def cross_section_derivatives(
self, wavelengths_nm: np.array, altitudes_m: np.array, **kwargs # noqa: ARG002
) -> dict:
derivs = {}
coords = self._database["xs_total"].coords
interp_handler = {}
interp_handler["wavelength_nm"] = wavelengths_nm
for name, vals in kwargs.items():
if name in coords:
interp_handler[name] = ("z", vals)
# Split the interpolators into ones that are 'z' dependent and ones that are not
interp_handler_z = {}
interp_handler_noz = {}
for key, val in interp_handler.items():
if val[0] == "z":
interp_handler_z[key] = val
else:
interp_handler_noz[key] = val
# Get the derivatives of the cross section with respect to the z dependent variables
partial_interp = self._database["xs_total"].interp(**interp_handler_noz)
for key, val in interp_handler_z.items():
# If the db only contains one element in the dimension we can't take a derivative
if len(self._database[key]) == 1:
continue
# Interpolate over the other variables
new_interpolants = {k: v for k, v in interp_handler_z.items() if k != key}
partial_interp2 = partial_interp.interp(**new_interpolants)
dT = partial_interp2.diff(key) / partial_interp2[key].diff(key)
interp_index = np.argmax(dT[key].to_numpy() > val[1][:, np.newaxis], axis=1)
if "z" in dT.dims:
dT = dT.isel(
{
"z": xr.DataArray(list(range(len(interp_index))), dims="z"),
key: xr.DataArray(interp_index, dims="z"),
}
)
else:
dT = dT.isel({key: xr.DataArray(interp_index, dims="z")})
derivs[key] = dT.to_numpy().flatten()
return derivs