Source code for sasktran2.database.mie

import hashlib
import json
from itertools import product
from pathlib import Path

import numpy as np
import xarray as xr

from sasktran2.mie.distribution import ParticleSizeDistribution
from sasktran2.mie.refractive import RefractiveIndex
from sasktran2.optical.database import OpticalDatabaseGenericScatterer

from .base import CachedDatabase


[docs] class MieDatabase(CachedDatabase, OpticalDatabaseGenericScatterer): def __init__( self, psize_distribution: ParticleSizeDistribution, refractive_index: RefractiveIndex, wavelengths_nm: np.array, db_root: Path | None = None, backend: str = "sasktran_legacy", **kwargs, ) -> None: """ A MieDatabase is a database that caches Mie scattering data on the local file system. The root directly can optionally be specified, otherwise the default folder is used. Parameters ---------- psize_distribution : ParticleSizeDistribution The particle size distribution refractive_index : RefractiveIndex The refractive index wavelengths_nm : np.array The wavelengths to calculate the Mie parameters at db_root : Path, optional The root directory to store the database, by default None backend : str, optional The backend to use, by default "sasktran_legacy" which requires the module `sasktran` to be installed. Currently supported options are ["sasktran_legacy"]. Default is "sasktran_legacy" kwargs Additional arguments to pass to the particle size distribution, these should match the psize_distribution.args() method """ class NumpyEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, np.ndarray): return obj.tolist() return json.JSONEncoder.default(self, obj) hasher = hashlib.sha1() encoded = json.dumps( {**kwargs, "wavelengths": wavelengths_nm}, sort_keys=True, cls=NumpyEncoder ).encode() hasher.update(encoded) identifier = hasher.hexdigest() CachedDatabase.__init__( self, db_root=db_root, rel_path=Path("mie") .joinpath(refractive_index.identifier) .joinpath(psize_distribution.identifier), ) self._data_file = self._db_root.joinpath(f"{identifier}.nc") self._psize_args = psize_distribution.args() self._psize_dist = psize_distribution self._refractive_index = refractive_index self._wavelengths_nm = wavelengths_nm self._backend = backend for arg in kwargs: if arg not in self._psize_args: msg = f"Invalid argument {arg} for particle size distribution" raise ValueError(msg) self._kwargs = kwargs OpticalDatabaseGenericScatterer.__init__(self, self.path()) def generate(self): if self._backend == "sasktran_legacy": self._generate_sasktran_legacy() else: msg = f"Invalid backend {self._backend}" raise ValueError(msg) def _generate_sasktran_legacy(self): """ Generates the data file """ try: from sasktran import MieWiscombe from sasktran.legendre import compute_greek_coefficients_legendre from sasktran.mie.distribution import integrated_mie except ImportError as err: msg = "sasktran_legacy is required to generate Mie databases, try pip install sasktran" raise ImportError(msg) from err def _create_table(distribution, refractive_index_fn, wavelengths): mie = MieWiscombe() vals = integrated_mie( mie, distribution, refractive_index_fn, wavelengths, num_quad=1000, maxintquantile=0.99999, ) # vals xs is in units of nm^2, convert to cm^2 then to m^2 vals["xs_total"] *= 1e-14 * 1e-4 vals["xs_scattering"] *= 1e-14 * 1e-4 vals["xs_absorption"] *= 1e-14 * 1e-4 lm_a1 = np.zeros_like(vals["lm_p11"].values) lm_a2 = np.zeros_like(lm_a1) lm_a3 = np.zeros_like(lm_a1) lm_a4 = np.zeros_like(lm_a1) lm_b1 = np.zeros_like(lm_a1) lm_b2 = np.zeros_like(lm_a1) for idx in range(len(vals.wavelength.values)): selected = vals.isel(wavelength=idx) ( lm_a1[idx, :], lm_a2[idx, :], lm_a3[idx, :], lm_a4[idx, :], lm_b1[idx, :], lm_b2[idx, :], ) = compute_greek_coefficients_legendre( selected["lm_p11"].values, selected["lm_p12"].values, selected["lm_p11"].values, selected["lm_p33"].values, selected["lm_p34"].values, selected["lm_p33"].values, ) vals["lm_a1"] = (["wavelength", "legendre"], lm_a1) vals["lm_a2"] = (["wavelength", "legendre"], lm_a2) vals["lm_a3"] = (["wavelength", "legendre"], lm_a3) vals["lm_a4"] = (["wavelength", "legendre"], lm_a4) vals["lm_b1"] = (["wavelength", "legendre"], lm_b1) vals["lm_b2"] = (["wavelength", "legendre"], lm_b2) ret_ds = xr.Dataset() ret_ds["xs_scattering"] = vals["xs_scattering"] ret_ds["xs_total"] = vals["xs_total"] ret_ds["lm_a1"] = vals["lm_a1"] ret_ds["lm_a2"] = vals["lm_a2"] ret_ds["lm_a3"] = vals["lm_a3"] ret_ds["lm_a4"] = vals["lm_a4"] ret_ds["lm_b1"] = vals["lm_b1"] ret_ds["lm_b2"] = vals["lm_b2"] return ret_ds refractive = self._refractive_index.refractive_index_fn entries = [] for vals in product(*self._kwargs.values()): psize_args = dict(zip(self._kwargs.keys(), vals, strict=True)) dist = self._psize_dist.distribution(**psize_args) entry = _create_table(dist, refractive, self._wavelengths_nm) for key, val in psize_args.items(): entry[key] = val entries.append(entry) ds = ( xr.concat(entries, dim="args") # noqa: PD010 .set_index(args=list(self._kwargs.keys())) .unstack("args") .rename_dims({"wavelength": "wavelength_nm"}) .rename_vars({"wavelength": "wavelength_nm"}) ) ds.to_netcdf(self._data_file)
[docs] def clear(self): if self._data_file.exists(): self._data_file.unlink()
def path(self, key: str | None = None, **kwargs) -> Path | None: # noqa: ARG002 if not self._data_file.exists(): self.generate() return self._data_file def load_ds(self, key: str | None = None, **kwargs) -> xr.Dataset: return xr.open_dataset(self.path(key, **kwargs))