Source code for sasktran2.constituent.brdf.kokhanovsky
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from sasktran2.mie.refractive import RefractiveIndex
from sasktran2.atmosphere import Atmosphere
from ..base import Constituent
from . import (
PyKokhanovsky,
WavelengthInterpolatorMixin,
)
[docs]
class SnowKokhanovsky(Constituent, WavelengthInterpolatorMixin):
[docs]
def __init__(
self,
L: np.array = 3600000,
M: np.array = 5.5e-8,
refractive_index_fn: RefractiveIndex = None,
wavelengths_nm: np.array = None,
out_of_bounds_mode="zero",
) -> None:
"""
Parameters
----------
L : np.array, optional
Kokhanovsky L parameter, by default 3600000
M : np.array, optional
Kokhanovsky M parameter, by default 5.5e-8
wavelengths_nm : np.array, optional
Wavelengths in [nm] that the parameters L, M is specified at, by default None indicating that L and M are scalar
out_of_bounds_mode : str, optional
One of ["extend" or "zero"], "extend" will extend the last/first value if we are
interpolating outside the grid. "zero" will set the albedo to 0 outside of the
grid boundaries, by default "zero"
"""
Constituent.__init__(self)
WavelengthInterpolatorMixin.__init__(
self,
wavelengths_nm=wavelengths_nm,
wavenumbers_cminv=None,
out_of_bounds_mode=out_of_bounds_mode,
param_length=len(np.atleast_1d(L)),
)
self._L = np.atleast_1d(L)
self._M = np.atleast_1d(M)
if refractive_index_fn is None:
from sasktran2.mie.refractive import Ice
self._refractive_index_fn = Ice()
else:
self._refractive_index_fn = refractive_index_fn
@property
def L(self) -> np.array:
return self._L
@L.setter
def L(self, L: np.array):
self._L = np.atleast_1d(L)
@property
def M(self) -> np.array:
return self._M
@M.setter
def M(self, M: np.array):
self._M = np.atleast_1d(M)
def add_to_atmosphere(self, atmo: Atmosphere):
if atmo.wavelengths_nm is None:
msg = (
"Atmosphere must have wavelengths defined before using SnowKokhonovsky"
)
raise ValueError(msg)
atmo.surface.brdf = PyKokhanovsky(atmo.nstokes)
interp_matrix = self._interpolating_matrix(atmo)
# args(0) is (chi + M) * L / wavelength_nm
# Where chi is the imaginary part of the ice refractive index
chi = -self._refractive_index_fn.refractive_index(atmo.wavelengths_nm).imag
M_interp = interp_matrix @ self._M
L_interp = interp_matrix @ self._L
atmo.surface.brdf_args[0, :] = (chi + M_interp) * L_interp / atmo.wavelengths_nm
def register_derivative(self, atmo: Atmosphere, name: str):
# Start by constructing the interpolation matrix
interp_matrix = self._interpolating_matrix(atmo)
derivs = {}
chi = -self._refractive_index_fn.refractive_index(atmo.wavelengths_nm).imag
# L Deriv factors are (chi + M_interp) / wavelength_nm
L_factor = (chi + (interp_matrix @ self._M)) / atmo.wavelengths_nm
deriv_mapping = atmo.surface.get_derivative_mapping(f"wf_{name}_L")
deriv_mapping.d_brdf[:] += L_factor.reshape(-1, 1)
deriv_mapping.interpolator = interp_matrix
deriv_mapping.interp_dim = f"{name}_wavelength"
deriv_mapping = atmo.surface.get_derivative_mapping(f"wf_{name}_M")
# M Deriv factors are L / wavelength_nm
M_factor = (interp_matrix @ self._L) / atmo.wavelengths_nm
deriv_mapping.d_brdf[:] += M_factor.reshape(-1, 1)
deriv_mapping.interpolator = interp_matrix
deriv_mapping.interp_dim = f"{name}_wavelength"
return derivs