from __future__ import annotations
from copy import copy
from pathlib import Path
import numpy as np
import xarray as xr
from sasktran2._core_rust import (
AbsorberDatabaseDim1,
AbsorberDatabaseDim2,
AbsorberDatabaseDim3,
PyScatteringDatabaseDim1,
PyScatteringDatabaseDim2,
PyScatteringDatabaseDim3,
PyScatteringDatabaseDim4,
)
from sasktran2.atmosphere import Atmosphere, NativeGridDerivative
from sasktran2.optical.base import OpticalProperty, OpticalQuantities
from sasktran2.polarization import LegendreStorageView
from .quantities import OpticalQuantities as RustOpticalQuantities
[docs]
class OpticalDatabase(OpticalProperty):
[docs]
def __init__(
self, db_filepath: Path | None = None, db: xr.Dataset | None = None
) -> None:
"""
An optical property that is defined by a database file. This is just a base class to handle file loading,
derived classes must be used as actual optical properties.
Parameters
----------
db_filepath : Path
Path to the optical database file
Raises
------
OSError
If xarray is not installed
"""
super().__init__()
if db_filepath is None and db is None:
msg = "Either db_filepath or db must be provided to OpticalDatabase"
raise ValueError(msg)
if db is not None:
self._database = db
else:
self._file = db_filepath
self._database = xr.open_dataset(self._file)
self._validate_db()
[docs]
class OpticalDatabaseGenericAbsorber(OpticalDatabase):
[docs]
def __init__(self, db_filepath: Path) -> None:
"""
A purely absorbing optical property defined by a database file. The database must contain the following
- xs : The absorption cross section in [m^2]
xs must be a function of either wavelength_nm or wavenumber_cminv, and optionally temperature and pressure.
Parameters
----------
db_filepath : Path
Path to the database file
"""
super().__init__(db_filepath)
def _validate_db(self):
if type(self._database) is not xr.Dataset:
return
# Old file format used "temperature" and "pressure" instead of the standard keys
# "temperature_k" and "pressure_pa", if so we rename them
if "temperature" in self._database:
self._database = self._database.rename({"temperature": "temperature_k"})
if "pressure" in self._database:
self._database = self._database.rename({"pressure": "pressure_pa"})
if "xs" not in self._database:
msg = "xs must be defined in the optical database"
raise ValueError(msg)
coords = list(self._database["xs"].coords)
if "wavenumber_cminv" in coords:
wavenumber_cminv = self._database["wavenumber_cminv"].to_numpy()
else:
wavenumber_cminv = 1e7 / self._database["wavelength_nm"].to_numpy()
coords[-1] = "wavenumber_cminv"
if len(coords) == 3:
param0 = self._database[coords[0]].to_numpy()
param1 = self._database[coords[1]].to_numpy()
sidx0 = np.argsort(param0)
sidx1 = np.argsort(param1)
sidx2 = np.argsort(wavenumber_cminv)
xs = self._database["xs"].to_numpy()[sidx0][:, sidx1][:, :, sidx2].copy()
self._database = AbsorberDatabaseDim3(
wavenumber_cminv[sidx2].astype(np.float64),
param0[sidx0].astype(np.float64),
param1[sidx1].astype(np.float64),
xs,
coords[:-1],
)
self._coords = coords
elif len(coords) == 2:
param0 = self._database[coords[0]].to_numpy()
sidx0 = np.argsort(param0)
sidx1 = np.argsort(wavenumber_cminv)
xs = self._database["xs"].to_numpy()[sidx0][:, sidx1].copy()
self._database = AbsorberDatabaseDim2(
wavenumber_cminv[sidx1].astype(np.float64),
param0[sidx0].astype(np.float64),
xs,
coords[:-1],
)
self._coords = coords
elif len(coords) == 1:
sidx1 = np.argsort(wavenumber_cminv)
xs = self._database["xs"].to_numpy()[sidx1].copy()
self._database = AbsorberDatabaseDim1(
wavenumber_cminv[sidx1].astype(np.float64),
xs,
)
self._coords = coords
def atmosphere_quantities(self, atmo, **kwargs):
# When called by the constituent this should be elided
return self._database.atmosphere_quantities(atmo)
def optical_derivatives(self, atmo, **kwargs):
# When called by the constituent this should be elided
return self._database.optical_derivatives(atmo)
def _into_rust_object(self):
return self._database
[docs]
class OpticalDatabaseGenericScattererRust(OpticalDatabase):
[docs]
def __init__(self, db_filepath: Path) -> None:
"""
A purely scattering optical property defined by a database file. The database must contain the following
- xs_total : The total cross section in [m^2]
- xs_scattering : The scattering cross section in [m^2]
- lm_a1 : the legendre coefficients for the phase function
All variables must be a function of either wavelength_nm or wavenumber_cminv, and optionally any other dimension such
as particle size.
This differs from OpticalDatabaseGenericScatterer in that it uses the Rust backend for the interpolation,
other than that the two classes are identical.
Parameters
----------
db_filepath : Path
Path to the database file
"""
super().__init__(db_filepath)
self._validate_db()
# Reorient the dimensions
dims = list(self._database["xs_total"].isel(wavelength_nm=0).dims)
db = self._database.transpose(*dims, "wavelength_nm", "legendre", ...)
# construct internal object
xs = db["xs_total"].to_numpy()
ssa = db["xs_scattering"].to_numpy()
db_shape = db["lm_a1"].shape
leg_shape = np.atleast_1d(copy(db_shape))
leg_shape[-1] *= 6
legendre = np.zeros(leg_shape)
legendre[..., 0::6] = db["lm_a1"].to_numpy()
legendre[..., 1::6] = db["lm_a2"].to_numpy()
legendre[..., 2::6] = db["lm_a3"].to_numpy()
legendre[..., 3::6] = db["lm_a4"].to_numpy()
legendre[..., 4::6] = db["lm_b1"].to_numpy()
legendre[..., 5::6] = db["lm_b2"].to_numpy()
wvnum = 1e7 / db["wavelength_nm"].to_numpy()
sidx = np.argsort(wvnum)
if len(xs.shape) == 1:
self._db = PyScatteringDatabaseDim1(
xs[sidx], ssa[sidx], legendre[sidx], wvnum[sidx]
)
elif len(xs.shape) == 2:
param_names = list(db["xs_total"].dims)[:-1]
param0 = db[param_names[0]].to_numpy()
self._db = PyScatteringDatabaseDim2(
xs[:, sidx],
ssa[:, sidx],
legendre[:, sidx, :],
wvnum[sidx],
np.atleast_1d(param0).astype(np.float64),
param_names,
)
elif len(xs.shape) == 3:
param_names = list(db["xs_total"].dims)[:-1]
param0 = db[param_names[0]].to_numpy()
param1 = db[param_names[1]].to_numpy()
self._db = PyScatteringDatabaseDim3(
xs[:, :, sidx],
ssa[:, :, sidx],
legendre[:, :, sidx, :],
wvnum[sidx],
np.atleast_1d(param0).astype(np.float64),
np.atleast_1d(param1).astype(np.float64),
param_names,
)
elif len(xs.shape) == 4:
param_names = list(db["xs_total"].dims)[:-1]
param0 = db[param_names[0]].to_numpy()
param1 = db[param_names[1]].to_numpy()
param2 = db[param_names[2]].to_numpy()
self._db = PyScatteringDatabaseDim4(
xs[:, :, :, sidx],
ssa[:, :, :, sidx],
legendre[:, :, :, sidx, :],
wvnum[sidx],
np.atleast_1d(param0).astype(np.float64),
np.atleast_1d(param1).astype(np.float64),
np.atleast_1d(param2).astype(np.float64),
param_names,
)
def _validate_db(self):
self._database["lm_a1"] /= self._database["lm_a1"].isel(legendre=0)
def cross_sections(
self, wavelengths_nm: np.array, altitudes_m: np.array, **kwargs
) -> OpticalQuantities:
return self._db.cross_sections(
np.atleast_1d(wavelengths_nm).astype(float),
np.atleast_1d(altitudes_m).astype(float),
**kwargs,
)
def atmosphere_quantities(self, atmo: Atmosphere, **kwargs) -> OpticalQuantities:
return RustOpticalQuantities(self._db.atmosphere_quantities(atmo, **kwargs))
def optical_derivatives(self, atmo: Atmosphere, **kwargs) -> dict:
result = self._db.optical_derivatives(atmo, **kwargs)
return {k: RustOpticalQuantities(v) for k, v in result.items()}
def cross_section_derivatives(
self, wavelengths_nm: np.array, altitudes_m: np.array, **kwargs
) -> dict:
result = self._db.cross_section_derivatives(
np.atleast_1d(wavelengths_nm).astype(float),
np.atleast_1d(altitudes_m).astype(float),
**kwargs,
)
return {k: v.extinction.flatten() for k, v in result.items()}
[docs]
class OpticalDatabaseGenericScatterer(OpticalDatabase):
[docs]
def __init__(self, db_filepath: Path) -> None:
"""
A purely scattering optical property defined by a database file. The database must contain the following
- xs_total : The total cross section in [m^2]
- xs_scattering : The scattering cross section in [m^2]
- lm_a1 : the legendre coefficients for the phase function
All variables must be a function of either wavelength_nm or wavenumber_cminv, and optionally any other dimension such
as particle size.
Parameters
----------
db_filepath : Path
Path to the database file
"""
super().__init__(db_filepath)
self._validate_db()
def _validate_db(self):
self._database["lm_a1"] /= self._database["lm_a1"].isel(legendre=0)
def _construct_interp_handler(self, atmo: Atmosphere, **kwargs) -> dict:
coords = self._database["xs_total"].coords
interp_handler = {}
if "wavelength_nm" in coords:
if atmo.wavelengths_nm is None:
msg = "wavelengths_nm must be specified in Atmosphere to use OpticalDatabaseGenericScatterer"
raise ValueError(msg)
interp_handler["wavelength_nm"] = atmo.wavelengths_nm
if "wavenumber_cminv" in coords:
if atmo.wavenumber_cminv is None:
msg = "wavenumber_cminv must be specified in Atmosphere to use OpticalDatabaseGenericScatterer"
raise ValueError(msg)
interp_handler["wavenumber_cminv"] = atmo.wavenumber_cminv
for name, vals in kwargs.items():
if name in coords:
interp_handler[name] = ("z", vals)
return interp_handler
def cross_sections(
self, wavelengths_nm: np.array, altitudes_m: np.array, **kwargs # noqa: ARG002
) -> OpticalQuantities:
quants = OpticalQuantities()
coords = self._database["xs_total"].coords
interp_handler = {}
interp_handler["wavelength_nm"] = wavelengths_nm
for name, vals in kwargs.items():
if name in coords:
interp_handler[name] = ("z", vals)
ds_interp = self._database.interp(**interp_handler)
quants.extinction = ds_interp["xs_total"].to_numpy()
quants.ssa = ds_interp["xs_scattering"].to_numpy() / quants.extinction
quants.extinction[np.isnan(quants.extinction)] = 0
quants.ssa[np.isnan(quants.ssa)] = 0
return quants
def atmosphere_quantities(self, atmo: Atmosphere, **kwargs) -> OpticalQuantities:
quants = OpticalQuantities()
interp_handler = self._construct_interp_handler(atmo, **kwargs)
num_assign_legendre = min(
atmo.leg_coeff.a1.shape[0], len(self._database["legendre"])
)
if atmo.nstokes == 1:
drop_vars = ["lm_a2", "lm_a3", "lm_a4", "lm_b1", "lm_b2"]
elif atmo.nstokes == 3:
drop_vars = ["lm_a4", "lm_b2"]
else:
drop_vars = []
ds_interp = (
self._database.isel(legendre=slice(0, num_assign_legendre))
.drop_vars(drop_vars)
.interp(**interp_handler)
)
if "z" not in ds_interp.dims:
# Our dataset has no z dependence
ds_interp = ds_interp.expand_dims(
dim={"z": len(atmo.model_geometry.altitudes())}
)
quants.extinction = np.copy(
ds_interp["xs_total"].transpose("z", "wavelength_nm").to_numpy()
)
quants.ssa = np.copy(
ds_interp["xs_scattering"].transpose("z", "wavelength_nm").to_numpy()
)
quants.leg_coeff = np.zeros_like(atmo.storage.leg_coeff)
leg_coeff = LegendreStorageView(quants.leg_coeff, atmo.nstokes)
leg_coeff.a1[:] = (
(ds_interp["lm_a1"])
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
# Renormalize leg_coeffs so a1 is always 1
leg_coeff.a1[:] /= quants.leg_coeff[0, :, :][np.newaxis, :, :]
quants.extinction[np.isnan(quants.extinction)] = 0
quants.ssa[np.isnan(quants.ssa)] = 0
if atmo.nstokes == 3:
# TODO: add pol properties
leg_coeff.a2[:] = (
(ds_interp["lm_a2"])
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
leg_coeff.a3[:] = (
(ds_interp["lm_a3"])
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
leg_coeff.b1[:] = (
(ds_interp["lm_b1"])
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
quants.leg_coeff[np.isnan(quants.leg_coeff)] = 0
return quants
def optical_derivatives(self, atmo: Atmosphere, **kwargs) -> dict:
derivs = {}
interp_handler = self._construct_interp_handler(atmo, **kwargs)
# Split the interpolators into ones that are 'z' dependent and ones that are not
interp_handler_z = {}
interp_handler_noz = {}
for key, val in interp_handler.items():
if val[0] == "z":
interp_handler_z[key] = val
else:
interp_handler_noz[key] = val
num_assign_legendre = min(
atmo.leg_coeff.a1.shape[0], len(self._database["legendre"])
)
if atmo.nstokes == 1:
drop_vars = ["lm_a2", "lm_a3", "lm_a4", "lm_b1", "lm_b2"]
elif atmo.nstokes == 3:
drop_vars = ["lm_a4", "lm_b2"]
else:
drop_vars = []
# Get the derivatives of the cross section with respect to the z dependent variables
partial_interp = (
self._database.isel(legendre=slice(0, num_assign_legendre))
.drop_vars(drop_vars)
.interp(**interp_handler_noz)
)
for key, val in interp_handler_z.items():
# If the db only contains one element in the dimension we can't take a derivative
if len(self._database[key]) == 1:
continue
# Interpolate over the other variables
new_interpolants = {k: v for k, v in interp_handler_z.items() if k != key}
partial_interp2 = partial_interp.interp(**new_interpolants)
dT = partial_interp2.diff(key) / partial_interp2[key].diff(key)
interp_index = np.argmax(dT[key].to_numpy() > val[1][:, np.newaxis], axis=1)
if "z" in dT.dims:
dT = dT.isel(
{
"z": xr.DataArray(list(range(len(interp_index))), dims="z"),
key: xr.DataArray(interp_index, dims="z"),
}
)
else:
dT = dT.isel({key: xr.DataArray(interp_index, dims="z")})
derivs[key] = NativeGridDerivative(
d_extinction=dT["xs_total"].transpose("z", "wavelength_nm").to_numpy(),
d_ssa=dT["xs_scattering"].transpose("z", "wavelength_nm").to_numpy(),
d_leg_coeff=np.zeros_like(atmo.storage.leg_coeff),
)
d_leg_coeff = LegendreStorageView(derivs[key].d_leg_coeff, atmo.nstokes)
d_leg_coeff.a1[:] = (
dT["lm_a1"]
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
if atmo.nstokes == 3:
d_leg_coeff.a2[:] = (
dT["lm_a2"]
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
d_leg_coeff.a3[:] = (
dT["lm_a3"]
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
d_leg_coeff.b1[:] = (
dT["lm_b1"]
.transpose("legendre", "z", "wavelength_nm")
.to_numpy()[:num_assign_legendre]
)
derivs[key].d_extinction[np.isnan(derivs[key].d_extinction)] = 0
derivs[key].d_ssa[np.isnan(derivs[key].d_ssa)] = 0
derivs[key].d_leg_coeff[np.isnan(derivs[key].d_leg_coeff)] = 0
return derivs
def cross_section_derivatives(
self, wavelengths_nm: np.array, altitudes_m: np.array, **kwargs # noqa: ARG002
) -> dict:
derivs = {}
coords = self._database["xs_total"].coords
interp_handler = {}
interp_handler["wavelength_nm"] = wavelengths_nm
for name, vals in kwargs.items():
if name in coords:
interp_handler[name] = ("z", vals)
# Split the interpolators into ones that are 'z' dependent and ones that are not
interp_handler_z = {}
interp_handler_noz = {}
for key, val in interp_handler.items():
if val[0] == "z":
interp_handler_z[key] = val
else:
interp_handler_noz[key] = val
# Get the derivatives of the cross section with respect to the z dependent variables
partial_interp = self._database["xs_total"].interp(**interp_handler_noz)
for key, val in interp_handler_z.items():
# If the db only contains one element in the dimension we can't take a derivative
if len(self._database[key]) == 1:
continue
# Interpolate over the other variables
new_interpolants = {k: v for k, v in interp_handler_z.items() if k != key}
partial_interp2 = partial_interp.interp(**new_interpolants)
dT = partial_interp2.diff(key) / partial_interp2[key].diff(key)
interp_index = np.argmax(dT[key].to_numpy() > val[1][:, np.newaxis], axis=1)
if "z" in dT.dims:
dT = dT.isel(
{
"z": xr.DataArray(list(range(len(interp_index))), dims="z"),
key: xr.DataArray(interp_index, dims="z"),
}
)
else:
dT = dT.isel({key: xr.DataArray(interp_index, dims="z")})
derivs[key] = dT.to_numpy().flatten()
return derivs