Source code for sasktran2.engine

from __future__ import annotations

import numpy as np
import xarray as xr

import sasktran2 as sk
from sasktran2._core_rust import PyEngine
from sasktran2.viewinggeo.base import ViewingGeometryContainer


def map_surface_derivative(
    mapping, np_deriv: np.ndarray, dims: list[str]
) -> xr.DataArray:
    """
    Wrap a raw surface derivative array in a labelled ``xr.DataArray``.

    When the mapping carries a non-empty interpolator, every element of
    ``np_deriv`` is multiplied by the interpolator weights, which adds a new
    leading axis labelled ``mapping.interp_dim``.  Otherwise the array is
    returned unchanged apart from the labelling.

    Parameters
    ----------
    mapping
        Derivative mapping object; its ``interpolator`` and ``interp_dim``
        attributes are read.
    np_deriv : np.ndarray
        Derivative values with at least two axes, labelled by ``dims``.
    dims : list[str]
        Dimension names for the axes of ``np_deriv``.

    Returns
    -------
    xr.DataArray
        ``np_deriv`` with dimensions ``dims`` when there is no interpolator,
        otherwise the weighted array with dimensions
        ``[mapping.interp_dim, *dims]``.
    """
    weights = mapping.interpolator
    has_weights = weights is not None and len(weights) > 0

    # Nothing to interpolate onto: just attach the dimension names.
    if not has_weights:
        return xr.DataArray(np_deriv, dims=dims)

    # out[l, i, j, ...] = np_deriv[i, j, ...] * weights[i, l]; no index is
    # summed over, the interpolator axis ``l`` simply becomes the leading axis.
    weighted = np.einsum(
        "ij..., il->lij...",
        np_deriv,
        weights,
        optimize=True,
    )
    return xr.DataArray(weighted, dims=[mapping.interp_dim, *dims])


class Engine:
    # Handle to the compiled engine that performs the actual calculation.
    _engine: PyEngine

    def __init__(
        self,
        config: sk.Config,
        geometry: sk.Geometry1D,
        viewing_geometry: sk.ViewingGeometry,
    ):
        """
        An Engine is the main class that handles the radiative transfer
        calculation.

        The calculation takes place in two components.  First, upon
        construction of the Engine, the majority of the geometry information
        is computed and cached.  The main calculation takes place when calling
        :py:meth:`~calculate_radiance` with an :py:class:`sasktran2.Atmosphere`
        object where the actual radiative transfer calculation is performed.

        Parameters
        ----------
        config : sk.Config
            Configuration object
        geometry : sk.Geometry1D
            Geometry for the model
        viewing_geometry : sk.ViewingGeometry
            Viewing geometry
        """
        native_config = config._config
        native_geometry = geometry._geometry
        native_viewing = viewing_geometry._viewing_geometry
        self._engine = PyEngine(native_config, native_geometry, native_viewing)

        # Keep the Python-side objects; calculate_radiance reads the config
        # and viewing geometry when assembling its output.
        self._config = config
        self._geometry = geometry
        self._viewing_geometry = viewing_geometry

    @staticmethod
    def _resolve_name(key, mapping):
        """Return ``mapping.assign_name`` unless it is empty, else ``key``."""
        if mapping.assign_name == "":
            return key
        return mapping.assign_name

    @staticmethod
    def _accumulate(dataset, name, values, mapping, trailing_dims):
        """Create ``dataset[name]`` from ``values`` or add to it if present."""
        if name not in dataset:
            dataset[name] = xr.DataArray(
                values,
                dims=[mapping.interp_dim, *trailing_dims],
            )
        else:
            # Several derivative keys can share one output name; sum them.
            dataset[name] += values

    @staticmethod
    def _mapped_surface_derivative(mapping, values, dims):
        """Map a surface derivative and drop the axis if it is named "dummy"."""
        mapped = map_surface_derivative(mapping, values, dims)
        if mapping.interp_dim == "dummy":
            mapped = mapped.isel(**{mapping.interp_dim: 0})
        return mapped

    def calculate_radiance(self, atmosphere: sk.Atmosphere) -> xr.Dataset:
        """
        Performs the radiative transfer calculation for a given atmosphere

        Parameters
        ----------
        atmosphere : sk.Atmosphere
            The atmosphere object containing the atmospheric profile and
            constituents

        Returns
        -------
        xr.Dataset
            An xarray dataset containing the radiance and derivatives.  The
            ``radiance`` variable has dimensions
            ``(wavelength, los, stokes)``.  ``<type>_flux`` variables are
            added when the viewing geometry has flux observers, and
            ``los_optical_depth`` is added when the configuration enables
            ``output_los_optical_depth``.
        """
        radiance_dims = ["wavelength", "los", "stokes"]
        flux_dims = ["wavelength", "flux_location"]

        raw = self._engine.calculate_radiance(atmosphere.internal_object())

        result = xr.Dataset()
        result["radiance"] = xr.DataArray(raw.radiance, dims=radiance_dims)

        # Translate the configured flux type codes into output name prefixes.
        flux_names = {
            0: "upwelling",
            1: "downwelling",
            2: "actinic",
            3: "divergence",
        }
        flux_types = [flux_names[int(code)] for code in self._config.flux_types]

        if len(self._viewing_geometry.flux_observers) > 0:
            # TODO: Grab this from the config
            for index, flux_type in enumerate(flux_types):
                result[f"{flux_type}_flux"] = xr.DataArray(
                    raw.flux[index],
                    dims=flux_dims,
                )

        wavelengths = atmosphere.wavelengths_nm
        if wavelengths is not None:
            result.coords["wavelength"] = wavelengths

        # Label only as many Stokes components as the radiance actually has.
        stokes_labels = ["I", "Q", "U", "V"]
        result.coords["stokes"] = stokes_labels[: len(result.stokes)]

        # Radiance derivatives with respect to atmospheric quantities.
        for key, deriv in raw.d_radiance.items():
            mapping = atmosphere.storage.get_derivative_mapping(key)
            self._accumulate(
                result,
                self._resolve_name(key, mapping),
                deriv,
                mapping,
                radiance_dims,
            )

        # Radiance derivatives with respect to surface quantities.
        for key, deriv in raw.d_radiance_surf.items():
            mapping = atmosphere.surface.get_derivative_mapping(key)
            result[key] = self._mapped_surface_derivative(
                mapping, deriv, radiance_dims
            )

        # Flux derivatives with respect to atmospheric quantities; the flux
        # type is indexed on the second axis of each derivative array.
        for key, deriv in raw.d_flux.items():
            mapping = atmosphere.storage.get_derivative_mapping(key)
            base_name = self._resolve_name(key, mapping)
            for index, flux_type in enumerate(flux_types):
                self._accumulate(
                    result,
                    f"{base_name}_{flux_type}_flux",
                    deriv[:, index],
                    mapping,
                    flux_dims,
                )

        # Flux derivatives with respect to surface quantities; here the flux
        # type is indexed on the first axis.
        for key, deriv in raw.d_flux_surf.items():
            mapping = atmosphere.surface.get_derivative_mapping(key)
            for index, flux_type in enumerate(flux_types):
                result[f"{key}_{flux_type}_flux"] = self._mapped_surface_derivative(
                    mapping, deriv[index], flux_dims
                )

        if isinstance(self._viewing_geometry, ViewingGeometryContainer):
            result = self._viewing_geometry.add_geometry_to_radiance(result)

        if self._config.output_los_optical_depth:
            result["los_optical_depth"] = xr.DataArray(
                raw.los_optical_depth,
                dims=["wavelength", "los"],
            )

        return result