Source code for sasktran2.optical.database

from __future__ import annotations

from pathlib import Path

import numpy as np
import xarray as xr

from sasktran2._core_rust import (
    AbsorberDatabaseDim1,
    AbsorberDatabaseDim2,
    AbsorberDatabaseDim3,
)
from sasktran2.atmosphere import Atmosphere, NativeGridDerivative
from sasktran2.optical.base import OpticalProperty, OpticalQuantities
from sasktran2.polarization import LegendreStorageView


class OpticalDatabase(OpticalProperty):
    def __init__(self, db_filepath: Path) -> None:
        """
        An optical property that is defined by a database file.  This is just a
        base class to handle file loading, derived classes must be used as
        actual optical properties.

        After the file is opened the constructor calls ``self._validate_db()``.
        That method is not defined on this class, so every derived class has to
        provide it.

        Parameters
        ----------
        db_filepath : Path
            Path to the optical database file

        Raises
        ------
        OSError
            If xarray is not installed
        """
        super().__init__()
        self._file = db_filepath

        try:
            import xarray as xr
        except ImportError as err:
            # A bare string cannot be raised; wrap the message in the
            # documented exception type and keep the ImportError as its cause.
            msg = f"xarray must be installed to use {type(self).__name__}"
            raise OSError(msg) from err

        self._database = xr.open_dataset(self._file)

        # Hook implemented by derived classes to check / convert the database.
        self._validate_db()
class OpticalDatabaseGenericAbsorber(OpticalDatabase):
    def __init__(self, db_filepath: Path) -> None:
        """
        A purely absorbing optical property defined by a database file. The
        database must contain the following

        - xs : The absorption cross section in [m^2]

        xs must be a function of either wavelength_nm or wavenumber_cminv, and
        optionally temperature and pressure.

        Parameters
        ----------
        db_filepath : Path
            Path to the database file
        """
        super().__init__(db_filepath)

    def _validate_db(self):
        """
        Check the loaded dataset and replace it with a compiled absorber
        database object.

        The legacy keys ``temperature`` / ``pressure`` are renamed to
        ``temperature_k`` / ``pressure_pa``, every coordinate grid is sorted
        ascending, and ``self._database`` is replaced by an
        ``AbsorberDatabaseDim1`` / ``Dim2`` / ``Dim3`` depending on how many
        coordinates ``xs`` has.  The coordinate names are stored in
        ``self._coords``.  Calling the method again after the conversion is a
        no-op.

        Raises
        ------
        ValueError
            If ``xs`` is missing from the database, or if ``xs`` has a number
            of coordinates other than 1, 2 or 3.
        """
        # Once converted, self._database is no longer an xarray Dataset and
        # there is nothing left to validate.
        if not isinstance(self._database, xr.Dataset):
            return

        # Old file format used "temperature" and "pressure" instead of the standard keys
        # "temperature_k" and "pressure_pa", if so we rename them
        if "temperature" in self._database:
            self._database = self._database.rename({"temperature": "temperature_k"})
        if "pressure" in self._database:
            self._database = self._database.rename({"pressure": "pressure_pa"})

        if "xs" not in self._database:
            msg = "xs must be defined in the optical database"
            raise ValueError(msg)

        coords = list(self._database["xs"].coords)

        if "wavenumber_cminv" in coords:
            wavenumber_cminv = self._database["wavenumber_cminv"].to_numpy()
        else:
            wavenumber_cminv = 1e7 / self._database["wavelength_nm"].to_numpy()
            # NOTE(review): assumes the spectral coordinate is the last entry
            # of coords and the last axis of xs - confirm against the files.
            coords[-1] = "wavenumber_cminv"

        num_coords = len(coords)
        if num_coords not in (1, 2, 3):
            # Previously this case fell through silently and left an xarray
            # Dataset in self._database, which only failed later on first use.
            msg = (
                "xs must have 1, 2, or 3 coordinates, "
                f"but the database has {num_coords}"
            )
            raise ValueError(msg)

        # Sort every grid ascending; xs is permuted along the matching axis.
        spectral_order = np.argsort(wavenumber_cminv)
        sorted_wavenumber = wavenumber_cminv[spectral_order].astype(np.float64)
        xs_raw = self._database["xs"].to_numpy()

        if num_coords == 3:
            param0 = self._database[coords[0]].to_numpy()
            param1 = self._database[coords[1]].to_numpy()
            order0 = np.argsort(param0)
            order1 = np.argsort(param1)

            xs = xs_raw[order0][:, order1][:, :, spectral_order].copy()

            self._database = AbsorberDatabaseDim3(
                sorted_wavenumber,
                param0[order0].astype(np.float64),
                param1[order1].astype(np.float64),
                xs,
                coords[:-1],
            )
        elif num_coords == 2:
            param0 = self._database[coords[0]].to_numpy()
            order0 = np.argsort(param0)

            xs = xs_raw[order0][:, spectral_order].copy()

            self._database = AbsorberDatabaseDim2(
                sorted_wavenumber,
                param0[order0].astype(np.float64),
                xs,
                coords[:-1],
            )
        else:
            xs = xs_raw[spectral_order].copy()

            self._database = AbsorberDatabaseDim1(
                sorted_wavenumber,
                xs,
            )

        self._coords = coords

    def atmosphere_quantities(self, atmo, **kwargs):
        """
        Forward to ``atmosphere_quantities`` of the compiled database object.
        ``kwargs`` are accepted but not used.
        """
        # When called by the constituent this should be elided
        return self._database.atmosphere_quantities(atmo)

    def optical_derivatives(self, atmo, **kwargs):
        """
        Forward to ``optical_derivatives`` of the compiled database object.
        ``kwargs`` are accepted but not used.
        """
        # When called by the constituent this should be elided
        return self._database.optical_derivatives(atmo)

    def _into_rust_object(self):
        """Return the compiled database object built by ``_validate_db``."""
        return self._database
class OpticalDatabaseGenericScatterer(OpticalDatabase):
    def __init__(self, db_filepath: Path) -> None:
        """
        A purely scattering optical property defined by a database file. The
        database must contain the following

        - xs_total : The total cross section in [m^2]
        - xs_scattering : The scattering cross section in [m^2]
        - lm_a1 : the legendre coefficients for the phase function

        All variables must be a function of either wavelength_nm or
        wavenumber_cminv, and optionally any other dimension such as particle
        size.

        Parameters
        ----------
        db_filepath : Path
            Path to the database file
        """
        # The base-class constructor opens the file and already invokes
        # _validate_db(), so it is not called a second time here.
        super().__init__(db_filepath)

    def _validate_db(self):
        """Divide ``lm_a1`` in place by its ``legendre=0`` entry."""
        self._database["lm_a1"] /= self._database["lm_a1"].isel(legendre=0)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _altitude_interpolants(coords, kwargs: dict) -> dict:
        """Map each keyword that is a database coordinate to ``("z", values)``."""
        return {name: ("z", vals) for name, vals in kwargs.items() if name in coords}

    @staticmethod
    def _split_interp_handler(interp_handler: dict) -> tuple[dict, dict]:
        """
        Split interpolants into those tagged with the ``"z"`` dimension and
        the rest.

        Returns
        -------
        tuple[dict, dict]
            ``(z_dependent, z_independent)``
        """
        z_dependent = {}
        z_independent = {}
        for key, val in interp_handler.items():
            # z-dependent entries are ("z", values) tuples; spectral grids are
            # plain arrays, so test the container type before looking inside.
            if isinstance(val, tuple) and val[0] == "z":
                z_dependent[key] = val
            else:
                z_independent[key] = val
        return z_dependent, z_independent

    @staticmethod
    def _unused_legendre_vars(nstokes: int) -> list[str]:
        """Names of the legendre variables dropped before interpolation for ``nstokes``."""
        if nstokes == 1:
            return ["lm_a2", "lm_a3", "lm_a4", "lm_b1", "lm_b2"]
        if nstokes == 3:
            return ["lm_a4", "lm_b2"]
        return []

    @staticmethod
    def _legendre_moments(ds, name: str, num_assign_legendre: int) -> np.ndarray:
        """Return ``ds[name]`` as a (legendre, z, wavelength_nm) array of the leading moments."""
        return (
            ds[name]
            .transpose("legendre", "z", "wavelength_nm")
            .to_numpy()[:num_assign_legendre]
        )

    @staticmethod
    def _finite_difference(partial_interp, interp_handler_z: dict, key: str):
        """
        Finite-difference derivative of ``partial_interp`` along the database
        dimension ``key``, picked per ``z`` entry.

        The remaining z-dependent interpolants are applied first, then the
        difference quotient along ``key`` is formed.  For every requested
        value the first interval whose ``key`` label exceeds that value is
        selected.
        """
        # Interpolate over the other variables
        others = {k: v for k, v in interp_handler_z.items() if k != key}
        on_grid = partial_interp.interp(**others)

        slope = on_grid.diff(key) / on_grid[key].diff(key)

        requested = np.asarray(interp_handler_z[key][1])
        # NOTE(review): argmax yields 0 when no label exceeds the requested
        # value, i.e. the first interval is used above the grid - confirm
        # this is intended.
        interp_index = np.argmax(
            slope[key].to_numpy() > requested[:, np.newaxis], axis=1
        )

        if "z" in slope.dims:
            return slope.isel(
                {
                    "z": xr.DataArray(np.arange(len(interp_index)), dims="z"),
                    key: xr.DataArray(interp_index, dims="z"),
                }
            )
        return slope.isel({key: xr.DataArray(interp_index, dims="z")})

    # ------------------------------------------------------------------
    # Interpolation set-up
    # ------------------------------------------------------------------
    def _construct_interp_handler(self, atmo: Atmosphere, **kwargs) -> dict:
        """
        Build the keyword arguments for interpolating the database onto
        ``atmo``.

        Spectral coordinates present on ``xs_total`` are taken from the
        atmosphere; every keyword that names a database coordinate is added as
        ``("z", values)``.

        Raises
        ------
        ValueError
            If the database uses a spectral coordinate that is ``None`` on the
            atmosphere.
        """
        coords = self._database["xs_total"].coords

        interp_handler = {}
        if "wavelength_nm" in coords:
            if atmo.wavelengths_nm is None:
                msg = "wavelengths_nm must be specified in Atmosphere to use OpticalDatabaseGenericScatterer"
                raise ValueError(msg)
            interp_handler["wavelength_nm"] = atmo.wavelengths_nm
        if "wavenumber_cminv" in coords:
            if atmo.wavenumber_cminv is None:
                msg = "wavenumber_cminv must be specified in Atmosphere to use OpticalDatabaseGenericScatterer"
                raise ValueError(msg)
            interp_handler["wavenumber_cminv"] = atmo.wavenumber_cminv

        interp_handler.update(self._altitude_interpolants(coords, kwargs))
        return interp_handler

    # ------------------------------------------------------------------
    # Public interface
    # ------------------------------------------------------------------
    def cross_sections(
        self, wavelengths_nm: np.ndarray, altitudes_m: np.ndarray, **kwargs  # noqa: ARG002
    ) -> OpticalQuantities:
        """
        Interpolate the database to ``wavelengths_nm`` and to any keyword that
        names a database coordinate.

        ``extinction`` is set from ``xs_total`` and ``ssa`` from
        ``xs_scattering / xs_total``; NaN entries of both are replaced by 0.
        ``altitudes_m`` is not used.
        """
        quants = OpticalQuantities()

        coords = self._database["xs_total"].coords
        interp_handler = {"wavelength_nm": wavelengths_nm}
        interp_handler.update(self._altitude_interpolants(coords, kwargs))

        ds_interp = self._database.interp(**interp_handler)

        quants.extinction = ds_interp["xs_total"].to_numpy()
        quants.ssa = ds_interp["xs_scattering"].to_numpy() / quants.extinction

        quants.extinction[np.isnan(quants.extinction)] = 0
        quants.ssa[np.isnan(quants.ssa)] = 0

        return quants

    def atmosphere_quantities(self, atmo: Atmosphere, **kwargs) -> OpticalQuantities:
        """
        Interpolate the database onto the atmosphere grid.

        ``extinction`` is filled from ``xs_total``, ``ssa`` from
        ``xs_scattering`` and ``leg_coeff`` from the ``lm_*`` variables, with
        NaN entries replaced by 0.  Only the first
        ``min(atmosphere moments, database moments)`` legendre moments are
        assigned; any further moments stay 0.
        """
        quants = OpticalQuantities()

        interp_handler = self._construct_interp_handler(atmo, **kwargs)

        num_assign_legendre = min(
            atmo.leg_coeff.a1.shape[0], len(self._database["legendre"])
        )

        ds_interp = (
            self._database.isel(legendre=slice(0, num_assign_legendre))
            .drop_vars(self._unused_legendre_vars(atmo.nstokes))
            .interp(**interp_handler)
        )

        if "z" not in ds_interp.dims:
            # Our dataset has no z dependence
            ds_interp = ds_interp.expand_dims(
                dim={"z": len(atmo.model_geometry.altitudes())}
            )

        # NOTE(review): the transposes name "wavelength_nm" explicitly; a
        # database keyed only by wavenumber_cminv would not have that
        # dimension - confirm whether such files are expected here.
        quants.extinction = np.copy(
            ds_interp["xs_total"].transpose("z", "wavelength_nm").to_numpy()
        )
        quants.ssa = np.copy(
            ds_interp["xs_scattering"].transpose("z", "wavelength_nm").to_numpy()
        )

        quants.leg_coeff = np.zeros_like(atmo.storage.leg_coeff)
        leg_coeff = LegendreStorageView(quants.leg_coeff, atmo.nstokes)

        # Assign into the leading moments only, so a database with fewer
        # moments than the atmosphere neither fails nor gets broadcast.
        leg_coeff.a1[:num_assign_legendre] = self._legendre_moments(
            ds_interp, "lm_a1", num_assign_legendre
        )

        # Renormalize leg_coeffs so a1 is always 1
        leg_coeff.a1[:] /= quants.leg_coeff[0, :, :][np.newaxis, :, :]

        quants.extinction[np.isnan(quants.extinction)] = 0
        quants.ssa[np.isnan(quants.ssa)] = 0

        if atmo.nstokes == 3:
            # TODO: add pol properties
            leg_coeff.a2[:num_assign_legendre] = self._legendre_moments(
                ds_interp, "lm_a2", num_assign_legendre
            )
            leg_coeff.a3[:num_assign_legendre] = self._legendre_moments(
                ds_interp, "lm_a3", num_assign_legendre
            )
            leg_coeff.b1[:num_assign_legendre] = self._legendre_moments(
                ds_interp, "lm_b1", num_assign_legendre
            )

        quants.leg_coeff[np.isnan(quants.leg_coeff)] = 0

        return quants

    def optical_derivatives(self, atmo: Atmosphere, **kwargs) -> dict:
        """
        Derivatives of the optical quantities with respect to every
        z-dependent keyword that is a database coordinate.

        Returns
        -------
        dict
            Maps the coordinate name to a ``NativeGridDerivative``.
            Coordinates with a single database entry are skipped.  NaN
            entries are replaced by 0.
        """
        derivs = {}

        interp_handler = self._construct_interp_handler(atmo, **kwargs)

        # Split the interpolators into ones that are 'z' dependent and ones that are not
        interp_handler_z, interp_handler_noz = self._split_interp_handler(
            interp_handler
        )

        num_assign_legendre = min(
            atmo.leg_coeff.a1.shape[0], len(self._database["legendre"])
        )

        # Get the derivatives of the cross section with respect to the z dependent variables
        partial_interp = (
            self._database.isel(legendre=slice(0, num_assign_legendre))
            .drop_vars(self._unused_legendre_vars(atmo.nstokes))
            .interp(**interp_handler_noz)
        )

        for key in interp_handler_z:
            # If the db only contains one element in the dimension we can't take a derivative
            if len(self._database[key]) == 1:
                continue

            dT = self._finite_difference(partial_interp, interp_handler_z, key)

            deriv = NativeGridDerivative(
                d_extinction=dT["xs_total"].transpose("z", "wavelength_nm").to_numpy(),
                d_ssa=dT["xs_scattering"].transpose("z", "wavelength_nm").to_numpy(),
                d_leg_coeff=np.zeros_like(atmo.storage.leg_coeff),
            )
            derivs[key] = deriv

            d_leg_coeff = LegendreStorageView(deriv.d_leg_coeff, atmo.nstokes)

            # Destination is sliced to match the number of available moments.
            d_leg_coeff.a1[:num_assign_legendre] = self._legendre_moments(
                dT, "lm_a1", num_assign_legendre
            )

            if atmo.nstokes == 3:
                d_leg_coeff.a2[:num_assign_legendre] = self._legendre_moments(
                    dT, "lm_a2", num_assign_legendre
                )
                d_leg_coeff.a3[:num_assign_legendre] = self._legendre_moments(
                    dT, "lm_a3", num_assign_legendre
                )
                d_leg_coeff.b1[:num_assign_legendre] = self._legendre_moments(
                    dT, "lm_b1", num_assign_legendre
                )

            deriv.d_extinction[np.isnan(deriv.d_extinction)] = 0
            deriv.d_ssa[np.isnan(deriv.d_ssa)] = 0
            deriv.d_leg_coeff[np.isnan(deriv.d_leg_coeff)] = 0

        return derivs

    def cross_section_derivatives(
        self, wavelengths_nm: np.ndarray, altitudes_m: np.ndarray, **kwargs  # noqa: ARG002
    ) -> dict:
        """
        Derivatives of ``xs_total`` with respect to every keyword that is a
        database coordinate.

        Returns
        -------
        dict
            Maps the coordinate name to a flattened numpy array.  Coordinates
            with a single database entry are skipped.  ``altitudes_m`` is not
            used.
        """
        derivs = {}

        coords = self._database["xs_total"].coords
        interp_handler = {"wavelength_nm": wavelengths_nm}
        interp_handler.update(self._altitude_interpolants(coords, kwargs))

        # Split the interpolators into ones that are 'z' dependent and ones that are not
        interp_handler_z, interp_handler_noz = self._split_interp_handler(
            interp_handler
        )

        # Get the derivatives of the cross section with respect to the z dependent variables
        partial_interp = self._database["xs_total"].interp(**interp_handler_noz)

        for key in interp_handler_z:
            # If the db only contains one element in the dimension we can't take a derivative
            if len(self._database[key]) == 1:
                continue

            dT = self._finite_difference(partial_interp, interp_handler_z, key)
            derivs[key] = dT.to_numpy().flatten()

        return derivs