from __future__ import annotations
import logging
import numpy as np
from sasktran2.atmosphere import Atmosphere
from sasktran2.mie.distribution import (
ParticleSizeDistribution,
integrate_mie_cpp,
)
from sasktran2.mie.refractive import RefractiveIndex
from sasktran2.optical.base import OpticalProperty, OpticalQuantities
from sasktran2.polarization import LegendreStorageView
class Mie(OpticalProperty):
    def __init__(
        self,
        psize_distribution: ParticleSizeDistribution,
        refractive_index: RefractiveIndex,
    ):
        """
        Mie scattering optical property where the Mie calculations are done on the fly.

        Note, this is much slower than using a precomputed Mie table, and in most cases it is
        recommended to use the MieDatabase class instead.

        Parameters
        ----------
        psize_distribution : ParticleSizeDistribution
            Particle size distribution to use for the Mie calculations
        refractive_index : RefractiveIndex
            Refractive index to use for the Mie calculations
        """
        self._psize_distribution = psize_distribution
        # A distribution that declares arguments needs them supplied through
        # **kwargs on every call, one set of values per row of the output.
        self._geometry_dependent = len(psize_distribution.args()) > 0
        self._refrac_index = refractive_index

        # Results of the most recent calculations, keyed by the tuple of
        # distribution argument values (the empty tuple when there are none).
        self._calculation_ds = {}
        self._xs_ds = {}

    def _distribution_arg_tuples(self, kwargs):
        """
        Stack the particle size distribution arguments found in ``kwargs``.

        Parameters
        ----------
        kwargs : dict
            Must contain one entry for every name in ``psize_distribution.args()``.

        Returns
        -------
        np.ndarray
            Transpose of the array built from the ``kwargs`` entries, so that
            iterating over it yields one row of argument values at a time.
        """
        return np.array(
            [kwargs[name] for name in self._psize_distribution.args()]
        ).T

    @staticmethod
    def _wavelength_last(ds, name):
        """
        Return ``ds[name]`` as a numpy array with ``wavelength_nm`` as its last axis.

        The axes are ordered by dimension name rather than by position so that
        the result does not depend on the axis order of the stored dataset.
        """
        return ds[name].transpose(..., "wavelength_nm").to_numpy()

    def _internal_generate(self, wavelengths_nm, result_ds, **kwargs):
        """
        Run the Mie integration and store the results in ``result_ds``.

        Parameters
        ----------
        wavelengths_nm : array-like
            Wavelengths handed to ``integrate_mie_cpp``.
        result_ds : dict
            Updated in place. One entry is written per unique combination of
            distribution argument values, keyed by the tuple of those values
            (the empty tuple when the distribution takes no arguments).
        **kwargs
            One entry per name in ``psize_distribution.args()``. The optional
            ``num_legendre`` (default 10) and ``num_threads`` (default 1) are
            forwarded to ``integrate_mie_cpp``.

        Raises
        ------
        ValueError
            If a particle size distribution argument is missing from ``kwargs``.
        """
        if len(wavelengths_nm) > 20:
            logging.warning(
                "Calculating Mie scattering parameters for a large number of wavelengths. You may want to use the MieDatabase class instead."
            )

        dist_args = self._psize_distribution.args()
        for arg in dist_args:
            if arg not in kwargs:
                msg = f"Missing argument {arg} for particle size distribution"
                raise ValueError(msg)

        if self._geometry_dependent:
            # Only integrate once per unique combination of argument values
            unique_args = np.unique(self._distribution_arg_tuples(kwargs), axis=0)

            if len(unique_args) > 20:
                logging.warning(
                    "Calculating Mie scattering parameters for a large number of particle size distribution arguments. You may want to use the MieDatabase class instead."
                )

            keys = [tuple(args) for args in unique_args]
            dists = [
                self._psize_distribution.distribution(
                    **dict(zip(dist_args, key, strict=True))
                )
                for key in keys
            ]
        else:
            keys = [()]
            dists = [self._psize_distribution.distribution()]

        ds = integrate_mie_cpp(
            dists,
            self._refrac_index.refractive_index_fn,
            wavelengths_nm,
            num_coeffs=kwargs.get("num_legendre", 10),
            num_threads=kwargs.get("num_threads", 1),
        )

        for i, key in enumerate(keys):
            result_ds[key] = ds.isel(distribution=i)

    def atmosphere_quantities(self, atmo: Atmosphere, **kwargs) -> OpticalQuantities:
        """
        Calculate the optical quantities on the grid of ``atmo``.

        Parameters
        ----------
        atmo : Atmosphere
            Supplies the wavelengths, the altitude grid, the number of Legendre
            coefficients, the number of Stokes parameters and the thread count.
        **kwargs
            One array per particle size distribution argument; row ``i`` of the
            output is filled from the ``i``-th set of argument values.

        Returns
        -------
        OpticalQuantities
            ``extinction`` holds ``xs_total`` and ``ssa`` holds ``xs_scattering``
            (the scattering cross section, not a ratio) exactly as returned by the
            Mie integration, with NaN replaced by 0. ``leg_coeff`` holds the
            ``lm_a1`` coefficients and, when ``atmo.nstokes == 3``, also
            ``lm_a2``, ``lm_b1`` and ``lm_a3``.

        Raises
        ------
        ValueError
            If a particle size distribution argument is missing from ``kwargs``.
        """
        self._internal_generate(
            atmo.wavelengths_nm,
            self._calculation_ds,
            **{
                **kwargs,
                "num_legendre": atmo.leg_coeff.a1.shape[0],
                "num_threads": atmo._config.num_threads,
            },
        )

        # Then calculate the atmosphere quantities
        shape = (len(atmo.model_geometry.altitudes()), len(atmo.wavelengths_nm))
        quants = OpticalQuantities(extinction=np.zeros(shape), ssa=np.zeros(shape))

        quants.leg_coeff = np.zeros_like(atmo.storage.leg_coeff)
        leg_coeff = LegendreStorageView(quants.leg_coeff, atmo.nstokes)

        # The polarized coefficients are only filled for 3 Stokes parameters
        lm_names = ("a1", "a2", "b1", "a3") if atmo.nstokes == 3 else ("a1",)

        if self._geometry_dependent:
            for i, arg in enumerate(self._distribution_arg_tuples(kwargs)):
                ds = self._calculation_ds[tuple(arg)].sel(
                    wavelength_nm=atmo.wavelengths_nm
                )

                quants.extinction[i, :] = ds["xs_total"].to_numpy()
                quants.ssa[i, :] = ds["xs_scattering"].to_numpy()

                # The storage views are indexed (legendre, altitude, wavelength),
                # so the coefficients must be written with wavelength last, the
                # same orientation used for the constant distribution below.
                for name in lm_names:
                    getattr(leg_coeff, name)[:, i, :] = self._wavelength_last(
                        ds, f"lm_{name}"
                    )
        else:
            ds = self._calculation_ds[()].sel(wavelength_nm=atmo.wavelengths_nm)

            quants.extinction[:] = ds["xs_total"].to_numpy()[np.newaxis, :]
            quants.ssa[:] = ds["xs_scattering"].to_numpy()[np.newaxis, :]

            # Broadcast the single set of coefficients over every altitude
            for name in lm_names:
                getattr(leg_coeff, name)[:] = self._wavelength_last(
                    ds, f"lm_{name}"
                )[:, np.newaxis, :]

        quants.extinction[np.isnan(quants.extinction)] = 0
        quants.ssa[np.isnan(quants.ssa)] = 0

        return quants

    def cross_sections(self, wavelengths_nm, altitudes_m, **kwargs):
        """
        Calculate the cross sections at the requested wavelengths and altitudes.

        Parameters
        ----------
        wavelengths_nm : array-like
            Wavelengths to calculate the cross sections at.
        altitudes_m : array-like
            Only its length is used, to size the first axis of the output.
        **kwargs
            One array per particle size distribution argument; row ``i`` of the
            output is filled from the ``i``-th set of argument values.

        Returns
        -------
        OpticalQuantities
            ``extinction`` holds ``xs_total`` exactly as returned by the Mie
            integration and ``ssa`` holds the ratio ``xs_scattering / xs_total``,
            both with NaN replaced by 0.

        Raises
        ------
        ValueError
            If a particle size distribution argument is missing from ``kwargs``.
        """
        self._internal_generate(wavelengths_nm, self._xs_ds, **kwargs)

        shape = (len(altitudes_m), len(wavelengths_nm))
        quants = OpticalQuantities(extinction=np.zeros(shape), ssa=np.zeros(shape))

        # A zero total cross section makes the ratio 0/0; the resulting NaN is
        # replaced below, so the floating point warning carries no information.
        with np.errstate(divide="ignore", invalid="ignore"):
            if self._geometry_dependent:
                for i, arg in enumerate(self._distribution_arg_tuples(kwargs)):
                    ds = self._xs_ds[tuple(arg)].sel(wavelength_nm=wavelengths_nm)

                    quants.extinction[i, :] = ds["xs_total"].to_numpy()
                    quants.ssa[i, :] = (
                        ds["xs_scattering"].to_numpy() / quants.extinction[i, :]
                    )
            else:
                ds = self._xs_ds[()].sel(wavelength_nm=wavelengths_nm)

                quants.extinction[:] = ds["xs_total"].to_numpy()
                quants.ssa[:] = ds["xs_scattering"].to_numpy() / quants.extinction

        quants.extinction[np.isnan(quants.extinction)] = 0
        quants.ssa[np.isnan(quants.ssa)] = 0

        return quants