# Source code for torchMWRT.rt_model

# -*- coding: utf-8 -*-
# SPDX-License-Identifier: GPL-3.0-only
# Derived from SatCloP/pyrtlib (GPL-3.0): https://github.com/SatCloP/pyrtlib
# Ported to PyTorch by David von Schlebrügge, 2026.
"""
Torch radiative-transfer model utilities.

This module ports pyrtlib radiative-transfer routines to PyTorch and combines
gas/cloud absorption with layer-by-layer radiance integration for batched
profiles, frequencies, and viewing angles.
"""


import warnings
from typing import Tuple, Optional, Union

import numpy as np
import torch
import xarray as xr

from . import consts
from .absorption_model import H2OAbsModel, O2AbsModel, N2AbsModel, LiqAbsModel, O3AbsModel
from .lineshape import ensure_lineshape_data_available

from .profile import AtmProfile


class RTModel:
    """Combined radiative-transfer and absorption model.

    Parameters
    ----------
    freqs : array-like
        Simulation frequencies in GHz.
    angles : array-like
        Viewing elevation angles in degrees. ``90`` denotes zenith-pointing
        (upward looking).
    absmdl : str, optional
        Absorption-model identifier (for example ``"R98"`` or ``"R17"``).
    ray_tracing : bool, optional
        If ``True``, use refractivity-dependent ray tracing. If ``False``, use
        plane-parallel slant-path geometry.
    from_sat : bool, optional
        If ``True``, integrate top-down (satellite view). If ``False``,
        integrate bottom-up (ground-based view).
    dtype : torch.dtype, optional
        Default dtype used for internal frequency/angle tensors and absorption
        model setup.
    device : str or torch.device, optional
        Device for internal tensors and absorption-model instances.
    amu : optional
        Optional line-shape parameter forwarded to absorption-model
        constructors.

    Raises
    ------
    ValueError
        If ``freqs`` is ``None``.

    Notes
    -----
    - Absorption-model instances are initialized during construction and
      reused for subsequent calls to :meth:`execute`.
    - Cloud absorption is enabled when ``profile.denliq`` and/or
      ``profile.denice`` of the provided :class:`AtmProfile` is not ``None``.
    - Input ``freqs`` and ``angles`` are converted to torch tensors on the
      configured dtype/device.
    - The integration direction used by :meth:`execute` is read from the
      private ``_from_sat`` attribute, which mirrors ``from_sat`` at
      construction time.
    - Ozone absorption is not enabled (``self.o3`` is ``None``).

    Provenance: the overall radiative-transfer and absorption workflow is
    adapted from pyrtlib (especially ``TbCloudRTE`` and ``RTEquation``) and
    ported to PyTorch for differentiable workflows.
    """

    # pyrtlib scales path integrals of refractivity and cloud density by 0.1
    # to express them in cm while its layer thicknesses are in km. Layer
    # thicknesses here are in m, so the equivalent factor is 0.1 / 1000.
    _PATH_INTEGRAL_TO_CM = 1e-4

    def __init__(
        self,
        freqs,
        angles,
        absmdl: str = "R98",
        ray_tracing: bool = False,
        from_sat: bool = False,
        dtype: torch.dtype = torch.float64,
        device=None,
        amu=None,
    ) -> None:
        if freqs is None:
            raise ValueError("RTModel requires freqs to initialise absorption models.")
        ensure_lineshape_data_available()
        self.dtype = dtype
        self.device = torch.device(device) if device is not None else None
        self.freqs = torch.as_tensor(freqs, dtype=self.dtype, device=self.device)
        self.angles = torch.as_tensor(angles, dtype=self.dtype, device=self.device)
        self.absmdl = absmdl
        self.ray_tracing = ray_tracing
        self.from_sat = from_sat
        # Private copy read by ``_planck`` and ``execute``.
        self._from_sat = from_sat
        model_kwargs = dict(freqs=self.freqs, dtype=self.dtype, device=self.device, amu=amu)
        self.h2o = H2OAbsModel(self.absmdl, **model_kwargs)
        self.o2 = O2AbsModel(self.absmdl, **model_kwargs)
        self.n2 = N2AbsModel(self.absmdl, **model_kwargs)
        self.liq = LiqAbsModel(self.absmdl, **model_kwargs)
        # No ozone model: ``_clearsky_absorption`` then uses zero O3 absorption.
        self.o3 = None

    def _tk2b_mod(self, hvk: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        r"""Compute the modified Planck function.

        Parameters
        ----------
        hvk : torch.Tensor
            Frequency-dependent factor :math:`h\nu/k_B` in K.
        t : torch.Tensor
            Temperature in K.

        Returns
        -------
        torch.Tensor
            Modified Planck function :math:`\tilde{B} = (e^{h\nu/(k_B T)} - 1)^{-1}`.

        Notes
        -----
        This is the torch equivalent of ``pyrtlib.utils.tk2b_mod`` and is kept
        differentiable for autograd workflows.

        Provenance: the physical definition and equation framing are taken
        from the docstring of ``pyrtlib.utils.tk2b_mod``.

        References
        ----------
        :cite:alp:`Schroeder-Westwater-1991`, Eq. (4)
        """
        return torch.reciprocal(torch.exp(hvk / t) - 1.0)

    def execute(
        self,
        profile: "AtmProfile",
        return_intermediate: bool = False,
        return_ds: bool = False,
    ) -> "dict[str, torch.Tensor] | xr.Dataset":
        """Run radiative-transfer integration for one atmospheric profile object.

        Parameters
        ----------
        profile : AtmProfile
            Atmospheric profile container with thermodynamic state, cloud
            fields, geometry, and optional batch coordinates.
        return_intermediate : bool, optional
            If ``True``, include intermediate absorption and integrated-path
            diagnostics in the output.
        return_ds : bool, optional
            If ``True``, return an ``xarray.Dataset``; otherwise return a
            dictionary of tensors.

        Returns
        -------
        dict[str, torch.Tensor] or xarray.Dataset
            ``tbtotal``, ``tbatm`` and ``tmr`` (shape ``(..., nf, nang)``) and,
            with ``return_intermediate``, integrated opacities ``tau*``
            (``(..., nf, nang)``), path integrals ``swet``/``sdry``/``sliq``/
            ``sice`` in cm (``(..., nang)``), and per-level fields ``awet``,
            ``adry``, ``taulay*`` (``(..., nf, nang, nl)``). Cloud diagnostics
            of an absent phase are zero-filled.

        Raises
        ------
        TypeError
            If ``profile`` is not an ``AtmProfile`` instance.
        RuntimeError
            If ray tracing fails or required intermediate arrays are missing.

        Notes
        -----
        The computation is vectorized over leading batch dimensions,
        frequency, and viewing angle. Absorption is computed once per
        frequency/state and reused across angles, then integrated along each
        slant path. ``tbatm`` is the atmosphere-only brightness temperature
        seen by the sensor (top of the column from the ground, surface level
        from a satellite).

        Provenance: output intent and naming conventions are adapted from
        ``pyrtlib.tb_spectrum.TbCloudRTE.execute`` and the underlying
        ``pyrtlib.rt_equation.RTEquation`` routines.
        """
        if not isinstance(profile, AtmProfile):
            raise TypeError(
                f"profile must be an AtmProfile instance, got {type(profile).__name__}.",
            )

        tk = profile.tk
        p = profile.p
        e = profile.e
        z = profile.z
        z0 = profile.z0
        denliq = profile.denliq
        denice = profile.denice
        o3n = profile.o3n
        has_liq = denliq is not None
        has_ice = denice is not None

        frq = self.freqs.to(device=tk.device, dtype=tk.dtype)
        angles = self.angles.to(device=tk.device, dtype=tk.dtype)
        nl = profile.nl
        batch_shape = profile.batch_shape
        nf = int(frq.shape[0])
        nang = int(angles.shape[0])
        emissivity = profile._load_emissivity(nf)

        dryn, wetn, refindx = self._refractivity(p, tk, e)

        # Absorption depends only on state and frequency, not on the look
        # angle: compute it once and reuse it for every angle.
        awet_base, adry_base = self._clearsky_absorption(p, tk, e, o3n)
        aliq_base = aice_base = None
        if has_liq or has_ice:
            aliq_base, aice_base = self._cloudy_absorption(tk, denliq, denice)

        # Slant-path layer lengths in m, shape (..., nang, nl).
        if self.ray_tracing:
            ds = self._ray_tracing_path(z, refindx, angles, z0)
            if ds is None:
                raise RuntimeError("ray_tracing_path returned None (invalid refractivity profile).")
        else:
            airmass = 1 / torch.sin(angles * torch.pi / 180.0)
            dz = torch.cat([torch.zeros(1, device=z.device, dtype=z.dtype), torch.diff(z)])
            ds_ang = dz.unsqueeze(0) * airmass.unsqueeze(1)  # (nang, nl)
            ds = ds_ang.reshape((1,) * len(batch_shape) + ds_ang.shape)

        # Path integrals used only by the optional diagnostics (cm).
        swet = sdry = sliq = sice = None
        awet = adry = None
        if return_intermediate:
            swet, _ = self._exponential_integration(
                True, wetn.unsqueeze(-2), ds, 0, nl, self._PATH_INTEGRAL_TO_CM
            )
            sdry, _ = self._exponential_integration(
                True, dryn.unsqueeze(-2), ds, 0, nl, self._PATH_INTEGRAL_TO_CM
            )
            if has_liq:
                sliq = self._cloud_integrated_density(denliq.unsqueeze(-2), ds)
            if has_ice:
                sice = self._cloud_integrated_density(denice.unsqueeze(-2), ds)
            # Per-angle views of the angle-independent absorption: (..., nf, nang, nl).
            awet = awet_base.unsqueeze(-2).expand(*awet_base.shape[:-1], nang, nl)
            adry = adry_base.unsqueeze(-2).expand(*adry_base.shape[:-1], nang, nl)

        # Integrate absorption along each slant path; singleton axes are placed
        # so results come out frequency-major: (..., nf, nang[, nl]).
        ds_abs = ds.unsqueeze(-3)  # (..., 1, nang, nl)
        sptauwet, ptauwet = self._exponential_integration(
            True, awet_base.unsqueeze(-2), ds_abs, 1, nl, 1
        )
        sptaudry, ptaudry = self._exponential_integration(
            True, adry_base.unsqueeze(-2), ds_abs, 1, nl, 1
        )
        ptaulay = ptauwet + ptaudry
        sptauliq = sptauice = ptauliq = ptauice = None
        if aliq_base is not None:
            sptauliq, ptauliq = self._exponential_integration(
                False, aliq_base.unsqueeze(-2), ds_abs, 1, nl, 1
            )
            ptaulay = ptaulay + ptauliq
        if aice_base is not None:
            sptauice, ptauice = self._exponential_integration(
                False, aice_base.unsqueeze(-2), ds_abs, 1, nl, 1
            )
            ptaulay = ptaulay + ptauice

        # _planck expects (ang, batch..., nf, nl) so t/emissivity broadcast cleanly.
        boftotl, boftatm, boftmr, _, hvk, _, _ = self._planck(
            tk,
            ptaulay.movedim(-2, 0),
            emissivity=emissivity,
        )
        # The accumulated atmosphere-only radiance sits at the top level for
        # bottom-up integration and at level 0 for top-down integration (where
        # boftatm[..., nl - 1] is always zero).
        atm_idx = 0 if self._from_sat else nl - 1
        # Move the angle axis last -> (batch..., nf, ang).
        tbtotal = self._bright(hvk, boftotl).movedim(0, -1)
        tbatm = self._bright(hvk, boftatm[..., atm_idx]).movedim(0, -1)
        tmr = self._bright(hvk, boftmr).movedim(0, -1)

        output: dict[str, torch.Tensor] = {
            "tbtotal": tbtotal,
            "tbatm": tbatm,
            "tmr": tmr,
        }
        if return_intermediate:
            if swet is None or sdry is None:
                raise RuntimeError("Internal error: clear-sky integrated refractivity was not computed.")
            if awet is None or adry is None:
                raise RuntimeError("Internal error: expanded absorption fields were not computed.")
            # Zero-fill each absent cloud phase on its own, so that a
            # liquid-only (or ice-only) profile keeps its computed values.
            if sptauliq is None:
                sptauliq = torch.zeros_like(sptauwet)
            if sptauice is None:
                sptauice = torch.zeros_like(sptauwet)
            if ptauliq is None:
                ptauliq = torch.zeros_like(ptauwet)
            if ptauice is None:
                ptauice = torch.zeros_like(ptauwet)
            if sliq is None:
                sliq = torch.zeros_like(swet)
            if sice is None:
                sice = torch.zeros_like(swet)
            output.update(
                {
                    "tauwet": sptauwet,
                    "taudry": sptaudry,
                    "tauliq": sptauliq,
                    "tauice": sptauice,
                    "swet": swet,
                    "sdry": sdry,
                    "sliq": sliq,
                    "sice": sice,
                    "awet": awet,
                    "adry": adry,
                    "taulaywet": ptauwet,
                    "taulaydry": ptaudry,
                    "taulayliq": ptauliq,
                    "taulayice": ptauice,
                }
            )

        if not return_ds:
            return output

        def _np(tensor: torch.Tensor) -> np.ndarray:
            return tensor.detach().cpu().numpy()

        batch_dims = list(profile.batch_coords.keys())
        fa_dims = (*batch_dims, "frq", "ang")
        ang_dims = (*batch_dims, "ang")
        level_dims = (*batch_dims, "frq", "ang", profile._level_dim)
        coords: dict[str, np.ndarray] = {**profile.batch_coords, "frq": _np(frq), "ang": _np(angles)}
        if return_intermediate:
            coords[profile._level_dim] = _np(z)

        # Every variable not listed here is a (..., frq, ang) field.
        dims_by_name = {
            "swet": ang_dims,
            "sdry": ang_dims,
            "sliq": ang_dims,
            "sice": ang_dims,
            "awet": level_dims,
            "adry": level_dims,
            "taulaywet": level_dims,
            "taulaydry": level_dims,
            "taulayliq": level_dims,
            "taulayice": level_dims,
        }
        data_vars: dict[str, tuple[tuple[str, ...], np.ndarray]] = {
            name: (dims_by_name.get(name, fa_dims), _np(value)) for name, value in output.items()
        }
        return xr.Dataset(data_vars=data_vars, coords=coords)

    def _bright(self, hvk: torch.Tensor, boft: torch.Tensor) -> torch.Tensor:
        r"""Convert modified Planck radiance to brightness temperature.

        Parameters
        ----------
        hvk : torch.Tensor
            Frequency-dependent factor :math:`h\nu/k_B` in K.
        boft : torch.Tensor
            Modified Planck radiance :math:`\tilde{B} = (e^{h\nu/(k_B T)} - 1)^{-1}`.

        Returns
        -------
        torch.Tensor
            Brightness temperature in K.

        Notes
        -----
        For zero radiance values the output is masked to zero. Zero entries
        are replaced by one before the logarithm so that the discarded branch
        stays finite and autograd does not produce NaN gradients.

        Provenance: the inverse-Planck relation is taken from the docstring of
        ``pyrtlib.rt_equation.RTEquation.bright``.

        References
        ----------
        :cite:alp:`Schroeder-Westwater-1991`, Eq. (4)
        """
        nonzero = boft != 0
        safe_boft = torch.where(nonzero, boft, torch.ones_like(boft))
        tb = hvk / torch.log1p(torch.reciprocal(safe_boft))
        return torch.where(nonzero, tb, torch.zeros_like(tb))

    def _refractivity(
        self, p: torch.Tensor, t: torch.Tensor, e: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute dry/wet refractivity and refractive index profiles.

        Parameters
        ----------
        p : torch.Tensor
            Total pressure profile in mbar.
        t : torch.Tensor
            Temperature profile in K.
        e : torch.Tensor
            Water-vapor partial pressure profile in mbar.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            ``(dryn, wetn, refindx)`` where ``dryn`` and ``wetn`` are
            refractivity components and ``refindx`` is refractive index.

        Notes
        -----
        The formulation follows Thayer and was originally intended for
        frequencies below 20 GHz.

        Provenance: the physical description and intended-frequency note are
        taken from the docstring of
        ``pyrtlib.rt_equation.RTEquation.refractivity``.

        References
        ----------
        :cite:alp:`Thayer-1974`
        """
        p_t, t_t, e_t = torch.broadcast_tensors(
            torch.as_tensor(p, device=t.device, dtype=t.dtype),
            torch.as_tensor(t, device=t.device, dtype=t.dtype),
            torch.as_tensor(e, device=t.device, dtype=t.dtype),
        )
        pa = p_t - e_t  # dry-air partial pressure
        tc = t_t - 273.16
        tk2 = t_t * t_t
        tc2 = tc * tc
        # Compressibility corrections for dry air (rza) and water vapor (rzw).
        rza = 1.0 + pa * (5.79e-07 * (1.0 + 0.52 / t_t) - (0.00094611 * tc) / tk2)
        rzw = 1.0 + 1650.0 * (e_t / (t_t * tk2)) * (
            1.0 - 0.01317 * tc + 0.000175 * tc2 + 1.44e-06 * (tc2 * tc)
        )
        wetn = (64.79 * (e_t / t_t) + 377600.0 * (e_t / tk2)) * rzw
        dryn = 77.6036 * (pa / t_t) * rza
        refindx = 1.0 + (dryn + wetn) * 1e-06
        return dryn, wetn, refindx

    def _ray_tracing_path(
        self,
        z: torch.Tensor,
        refindx: torch.Tensor,
        angle: Union[float, torch.Tensor],
        z0: float,
    ) -> Union[torch.Tensor, None]:
        """Compute slant-path layer distances with refractive ray tracing.

        Parameters
        ----------
        z : torch.Tensor
            1-D height profile in meters above observation height ``z0``.
        refindx : torch.Tensor
            Refractive-index profile, with last dimension matching ``z``.
        angle : float or torch.Tensor
            Elevation angle(s) in degrees.
        z0 : float
            Observation height in meters above mean sea level.

        Returns
        -------
        torch.Tensor or None
            Slant-path layer lengths in meters. With multiple angles, shape is
            ``(..., nang, nl)``. With a scalar angle, shape is ``(..., nl)``.
            Returns ``None`` (with a warning) if any refractive index is below 1.

        Notes
        -----
        This is a vectorized port of the Dutton-Thayer-Westwater algorithm.
        The method assumes exponential decay of refractivity over each layer.
        Angles within 1 degree of +/-90 use the plain height differences.
        If ducting is detected, integration stops and the remaining layers
        keep zero length.

        Provenance: the algorithm description and the exponential-layer
        assumption are taken from the docstring of
        ``pyrtlib.rt_equation.RTEquation.ray_tracing``.

        References
        ----------
        :cite:alp:`Bean-Dutton`, Fig. 3.20 and surrounding text
        """
        # The algorithm works in km internally.
        z_t = torch.as_tensor(z) / 1000.0
        angle_t = torch.as_tensor(angle, dtype=z_t.dtype, device=z_t.device)
        scalar_out = angle_t.dim() == 0
        if scalar_out:
            angle_t = angle_t.unsqueeze(0)
        ref_t = torch.as_tensor(refindx, device=z_t.device, dtype=z_t.dtype)
        nl = int(z_t.numel())
        if ref_t.shape[-1] != nl:
            raise ValueError(
                f"refindx last dimension must match len(z)={nl}; got refindx shape {tuple(ref_t.shape)}."
            )
        batch_shape = tuple(ref_t.shape[:-1])
        deg2rad = torch.tensor(torch.pi / 180.0, dtype=z_t.dtype, device=z_t.device)
        re = torch.tensor(consts.EarthRadius, dtype=z_t.dtype, device=z_t.device) / 1000.0
        z0_t = torch.as_tensor(z0, dtype=z_t.dtype, device=z_t.device) / 1000.0
        nang = int(angle_t.shape[0])
        ds = torch.zeros(batch_shape + (nang, nl), dtype=z_t.dtype, device=z_t.device)

        # Refractive index below 1 (negative refractivity) would blow up the logs below.
        if torch.any(ref_t < 1):
            warnings.warn('ray_tracing: Negative refractivity (refractive index < 1)')
            return None

        # Near-vertical angles: the slant path is just the height difference.
        mask_90 = ((angle_t >= 89) & (angle_t <= 91)) | ((angle_t >= -91) & (angle_t <= -89))
        if torch.any(mask_90):
            dz = torch.zeros_like(z_t)
            dz[1:] = torch.diff(z_t)
            ds[..., mask_90, :] = dz

        mask_non90 = ~mask_90
        if torch.any(mask_non90):
            theta0 = angle_t[mask_non90] * deg2rad  # (nang_non90,)
            nang_non90 = int(theta0.shape[0])
            rs = re + z_t[0] + z0_t
            rl = re + z_t[0] + z0_t
            costh0 = torch.cos(theta0)
            sina = torch.sin(theta0 * 0.5)
            a0 = 2.0 * (sina ** 2)

            # Angle-dependent terms broadcast over any batch dims.
            ang_shape = (1,) * len(batch_shape) + (nang_non90,)
            theta0_b = theta0.view(ang_shape)
            costh0_b = costh0.view(ang_shape)
            sina_b = sina.view(ang_shape)
            a0_b = a0.view(ang_shape)

            phil = torch.zeros(batch_shape + (nang_non90,), dtype=z_t.dtype, device=z_t.device)
            taul = torch.zeros_like(phil)
            tanthl = torch.tan(theta0_b)
            one = torch.tensor(1.0, dtype=z_t.dtype, device=z_t.device)
            ds_non = torch.zeros(batch_shape + (nang_non90, nl), dtype=z_t.dtype, device=z_t.device)
            r0 = ref_t[..., 0].unsqueeze(-1)

            for i in range(1, nl):
                r = re + z_t[i] + z0_t
                ri = ref_t[..., i].unsqueeze(-1)
                r_prev = ref_t[..., i - 1].unsqueeze(-1)

                # Refractivity-weighted mean index; fall back to the arithmetic
                # mean where the log-ratio form is ill-conditioned.
                close = (
                    torch.isclose(ri, r_prev)
                    | torch.isclose(ri, one)
                    | torch.isclose(r_prev, one)
                )
                denom = torch.where(close, torch.ones_like(ri), ri - one)
                ratio = (r_prev - one) / denom
                safe_ratio = torch.where(close, torch.ones_like(ratio), ratio)
                safe_log = torch.log(safe_ratio)
                safe_log = torch.where(close, torch.ones_like(safe_log), safe_log)
                refbar_alt = one + (r_prev - ri) / safe_log
                refbar = torch.where(close, 0.5 * (ri + r_prev), refbar_alt)

                argdth = z_t[i] / rs - ((r0 - ri) * costh0_b / ri)
                argth = 0.5 * (a0_b + argdth) / r
                if torch.any(argth <= 0):
                    warnings.warn('ray_tracing: Ducting encountered (argth <= 0)')
                    break

                sint = torch.sqrt(r * argth)
                theta = 2.0 * torch.asin(sint)
                cond = theta - 2.0 * theta0_b <= 0.0
                dendth = 2.0 * (sint + sina_b) * torch.cos((theta + theta0_b) * 0.25)
                sind4 = (0.5 * argdth - z_t[i] * argth) / dendth
                dtheta = torch.where(cond, 4.0 * torch.asin(sind4), theta - theta0_b)
                theta = theta0_b + dtheta

                tanth = torch.tan(theta)
                cthbar = 0.5 * ((1.0 / tanth) + (1.0 / tanthl))
                dtau = cthbar * (r_prev - ri) / refbar
                tau = taul + dtau
                phi = dtheta + tau

                ds_i = torch.sqrt(
                    (z_t[i] - z_t[i - 1]) ** 2
                    + 4.0 * r * rl * (torch.sin((phi - phil) * 0.5) ** 2)
                )
                mask_dtau = dtau != 0.0
                if torch.any(mask_dtau):
                    # Chord-to-arc correction for the bent ray segment.
                    dtaua = torch.abs(tau - taul)
                    ds_i = torch.where(
                        mask_dtau, ds_i * (dtaua / (2.0 * torch.sin(dtaua * 0.5))), ds_i
                    )
                ds_non[..., i] = ds_i

                phil = phi
                taul = tau
                rl = r
                tanthl = tanth

            ds[..., mask_non90, :] = ds_non

        if scalar_out:
            return ds[..., 0, :] * 1000.0
        return ds * 1000.0

    def _exponential_integration(
        self,
        zeroflg: bool,
        x: torch.Tensor,
        ds: torch.Tensor,
        ibeg: int,
        iend: int,
        factor: float,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Integrate exponentially varying layer quantities.

        Parameters
        ----------
        zeroflg : bool
            Zero-handling flag from pyrtlib logic. If ``True``, zero crossings
            use layer-average values; otherwise they contribute zero.
        x : torch.Tensor
            Quantity profile(s) to integrate (levels on the last axis).
        ds : torch.Tensor
            Layer thickness profile(s) in meters, broadcastable against ``x``.
        ibeg : int
            Inclusive lower layer index.
        iend : int
            Exclusive upper layer index.
        factor : float
            Multiplicative scaling applied to integrated sums.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            ``(sxds, xds)`` where ``sxds`` is the integrated profile sum and
            ``xds`` holds per-layer contributions (zero outside
            ``[ibeg, iend)``).

        Notes
        -----
        This is a vectorized torch port of pyrtlib's exponential integration.
        Layers with a negative endpoint emit a warning and contribute zero.
        The near-equality threshold is scaled from the original Np/km context
        to Np/m to preserve branching behavior in this code path. Entries that
        do not use the log-ratio formula are fed safe placeholder values so
        that the discarded branch never evaluates 0/0 (which would otherwise
        yield NaN gradients through ``torch.where``).

        Provenance: algorithm intent and control-flag behavior are taken from
        the docstring of
        ``pyrtlib.rt_equation.RTEquation.exponential_integration``.
        """
        x_t = torch.as_tensor(x)
        ds_t = torch.as_tensor(ds, device=x_t.device, dtype=x_t.dtype)
        x_t, ds_t = torch.broadcast_tensors(x_t, ds_t)

        # x[i - 1] for every level; level 0 wraps to the last level, mirroring
        # Python's x[-1] in the original loop when ibeg == 0.
        x_prev_full = torch.roll(x_t, shifts=1, dims=-1)
        x_curr = x_t[..., ibeg:iend]
        x_prev = x_prev_full[..., ibeg:iend]
        ds_slice = ds_t[..., ibeg:iend]

        neg_mask = (x_curr < 0.0) | (x_prev < 0.0)
        if torch.any(neg_mask):
            warnings.warn('Error encountered in exponential_integration')

        # pyrtlib's close-value branch uses 1e-9 with absorption in Np/km.
        # Here absorption is in Np/m, so use the scaled threshold to preserve
        # the original branching behavior and avoid integration bias.
        mask_close = torch.abs(x_curr - x_prev) < 1e-12
        mask_zero = (x_curr == 0.0) | (x_prev == 0.0)
        mask_fallback = mask_close | mask_zero | neg_mask

        safe_prev = torch.where(mask_fallback, torch.ones_like(x_prev), x_prev)
        safe_curr = torch.where(mask_fallback, torch.ones_like(x_curr), x_curr)
        log_ratio = torch.log(safe_curr / safe_prev)
        log_ratio = torch.where(mask_fallback, torch.ones_like(log_ratio), log_ratio)
        xlayer = (safe_curr - safe_prev) / log_ratio

        xlayer = torch.where(mask_close, x_curr, xlayer)
        if zeroflg:
            xlayer = torch.where(mask_zero, 0.5 * (x_curr + x_prev), xlayer)
        else:
            xlayer = torch.where(mask_zero, torch.zeros_like(xlayer), xlayer)
        xlayer = torch.where(neg_mask, torch.zeros_like(xlayer), xlayer)

        xds = torch.zeros_like(x_t)
        xds[..., ibeg:iend] = xlayer * ds_slice
        sxds = torch.sum(xds, dim=-1) * factor
        return sxds, xds

    def _cloud_radiating_temperature(
        self,
        ibase: float,
        itop: float,
        hvk: torch.Tensor,
        tauprof: torch.Tensor,
        boftatm: torch.Tensor,
    ) -> Union[torch.Tensor, None]:
        r"""Compute mean radiating temperature for the lowest cloud layer.

        Parameters
        ----------
        ibase : float
            Profile index of cloud base (converted with ``int``).
        itop : float
            Profile index of cloud top (converted with ``int``).
        hvk : torch.Tensor
            Frequency factor :math:`h\nu/k_B`.
        tauprof : torch.Tensor
            Integrated absorption profile, levels on the last axis.
        boftatm : torch.Tensor
            Integrated atmospheric modified Planck radiance profile, levels on
            the last axis.

        Returns
        -------
        torch.Tensor or None
            Mean cloud radiating temperature in K.

        Notes
        -----
        This routine is designed for the lowest (or only) cloud layer.
        torchMWRT keeps the original logic but avoids hard early return by
        clamping exponent arguments when optical depth is large. Levels are
        indexed on the last axis, so leading batch dimensions are supported.

        Provenance: cloud-layer assumptions and formulation description are
        taken from the docstring of
        ``pyrtlib.rt_equation.RTEquation.cloud_radiating_temperature``.
        """
        # Maximum absolute value for an exponential-function argument.
        expmax = 125.0
        ibase = int(ibase)
        itop = int(itop)

        tau_base = tauprof[..., ibase]
        if torch.any(tau_base > expmax):
            warnings.warn(
                'from cloud_radiating_temperature: absorption too large to exponentiate for tmr of lowest cloud layer')

        # Radiance (batmcld) and absorption (taucld) of the cloud layer; if
        # taucld is too large to exponentiate, treat it as infinite.
        batmcld = boftatm[..., itop] - boftatm[..., ibase]
        taucld = tauprof[..., itop] - tauprof[..., ibase]
        exp_tau_base = torch.exp(torch.clamp(tau_base, max=expmax))
        exp_neg_taucld = torch.exp(-torch.clamp(taucld, max=expmax))
        boftcld = (batmcld * exp_tau_base) / (1.0 - exp_neg_taucld)
        boftcld = torch.where(taucld > expmax, batmcld * exp_tau_base, boftcld)

        return self._bright(hvk, boftcld)

    def _cloud_integrated_density(self, dencld: torch.Tensor, ds: torch.Tensor) -> torch.Tensor:
        """Integrate cloud water/ice density along the propagation path.

        Parameters
        ----------
        dencld : torch.Tensor
            Cloud condensate density profile in g/m^3.
        ds : torch.Tensor
            Layer path lengths in meters.

        Returns
        -------
        torch.Tensor
            Path-integrated cloud density in cm (water-equivalent convention).

        Notes
        -----
        Unlike pyrtlib's version that integrates between explicit cloud
        base/top indices, this torchMWRT implementation integrates over all
        layers and supports optional leading batch dimensions. pyrtlib's 0.1
        factor assumes ``ds`` in km; with ``ds`` in m the factor becomes 1e-4
        (1 g/m^2 = 1e-4 cm of water).

        Provenance: base physical quantity and unit convention are taken from
        the docstring of
        ``pyrtlib.rt_equation.RTEquation.cloud_integrated_density``.
        """
        den_t = torch.as_tensor(
            dencld,
            device=ds.device if isinstance(ds, torch.Tensor) else None,
            dtype=ds.dtype if isinstance(ds, torch.Tensor) else None,
        )
        ds_t = torch.as_tensor(ds, device=den_t.device, dtype=den_t.dtype)
        den_b, ds_b = torch.broadcast_tensors(den_t, ds_t)
        # Layer mean of adjacent levels; layer 0 (wrapped pair) is skipped.
        avg = 0.5 * (den_b + torch.roll(den_b, shifts=1, dims=-1))
        term = ds_b * avg
        return torch.sum(term[..., 1:], dim=-1) * self._PATH_INTEGRAL_TO_CM

    def _planck(
        self,
        t: torch.Tensor,
        taulay: torch.Tensor,
        emissivity: Optional[torch.Tensor] = None,
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
    ]:
        """Compute modified-Planck radiance terms for RTE integration.

        Parameters
        ----------
        t : torch.Tensor
            Temperature profile in K, with last dimension ``nl``.
        taulay : torch.Tensor
            Layer optical-depth profile with shape ``(..., nf, nl)``. Layer
            ``i`` spans levels ``i - 1`` to ``i``; ``taulay[..., 0]`` is ignored.
        emissivity : torch.Tensor, optional
            Surface emissivity, scalar or broadcastable to ``(..., nf)``.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            ``(boftotl, boftatm, boftmr, tauprof, hvk, boft, bakgrnd)`` where
            each term follows pyrtlib semantics in torch tensor form.

        Raises
        ------
        ValueError
            If input shapes are inconsistent with ``(..., nf, nl)`` handling.

        Notes
        -----
        The routine computes atmospheric and background contributions for
        both upwelling and downwelling configurations and preserves the
        original optical-depth clipping strategy (``expmax = 125``). In the
        top-down branch the surface reflects a zero radiance term.

        Provenance: algorithm purpose and returned physical terms are taken
        from the docstring of ``pyrtlib.rt_equation.RTEquation.planck``.

        References
        ----------
        :cite:alp:`Schroeder-Westwater-1992`, Eq. (4)
        """
        t_t = torch.as_tensor(t)
        taulay_t = torch.as_tensor(taulay, device=t_t.device, dtype=t_t.dtype)
        if taulay_t.dim() < 2:
            raise ValueError("taulay must have at least 2 dimensions (..., nf, nl).")
        nl = int(taulay_t.shape[-1])
        nf = int(taulay_t.shape[-2])
        leading_shape = taulay_t.shape[:-2]
        if t_t.shape[-1] != nl:
            raise ValueError(
                f"Temperature last dimension must match taulay nl={nl}; got t shape {tuple(t_t.shape)} "
                f"and taulay shape {tuple(taulay_t.shape)}."
            )

        Tc = torch.tensor(consts.Tcosmicbkg, dtype=t_t.dtype, device=t_t.device)
        h = torch.tensor(consts.planck, dtype=t_t.dtype, device=t_t.device)
        k = torch.tensor(consts.boltzmann, dtype=t_t.dtype, device=t_t.device)

        frq_t = self.freqs.to(device=t_t.device, dtype=t_t.dtype)
        if frq_t.dim() == 0:
            frq_t = frq_t.reshape(1)
        if frq_t.numel() != nf:
            raise ValueError(
                f"taulay has nf={nf} but RTModel.freqs has {frq_t.numel()} elements."
            )

        # Broadcast temperature to the taulay leading dimensions (angle, batch, ...).
        while t_t.dim() < len(leading_shape) + 1:
            t_t = t_t.unsqueeze(0)
        t_b = t_t.expand(leading_shape + (nl,))
        t_bf = t_b.unsqueeze(-2)  # (..., 1, nl) -> broadcasts across nf

        frq_b = frq_t.reshape((1,) * len(leading_shape) + (nf,))
        hvk = ((frq_b * 1e9) * h / k).expand(leading_shape + (nf,))  # (..., nf)
        hvk_full = hvk.unsqueeze(-1)  # (..., nf, 1)

        expmax = 125.0
        boft = self._tk2b_mod(hvk_full, t_bf)  # (..., nf, nl)

        # Surface emissivity: scalar, per-frequency, or batched. If the last
        # axis is not a frequency axis, treat it as one value per profile.
        emissivity_t = torch.as_tensor(
            1.0 if emissivity is None else emissivity,
            device=t_t.device,
            dtype=t_t.dtype,
        )
        if emissivity_t.dim() != 0 and emissivity_t.shape[-1] not in (1, nf):
            emissivity_t = emissivity_t.unsqueeze(-1)
        while emissivity_t.dim() < len(leading_shape) + 1:
            emissivity_t = emissivity_t.unsqueeze(0)
        emissivity_f = emissivity_t.expand(leading_shape + (nf,))

        # Single-level profiles keep zero optical depth and zero atmospheric radiance.
        tauprof = torch.zeros_like(taulay_t)
        boftatm = torch.zeros_like(taulay_t)

        if self._from_sat:
            # Reflected downwelling radiance is taken as zero in this branch.
            boftotl = torch.zeros(leading_shape + (nf,), dtype=t_t.dtype, device=t_t.device)
            # Surface temperature with a trailing axis so it broadcasts over nf.
            Ts = t_b[..., 0].unsqueeze(-1)  # (..., 1)
            if nl > 1:
                taulay_tail = taulay_t[..., 1:]
                # tauprof[i] = sum of taulay above level i; zero at the top.
                tauprof_tail = torch.cumsum(taulay_tail.flip(-1), dim=-1).flip(-1)
                tauprof = torch.cat([tauprof_tail, torch.zeros_like(taulay_t[..., :1])], dim=-1)
                exp_tau = torch.exp(-taulay_tail)
                boftlay = (boft[..., 1:] + boft[..., :-1] * exp_tau) / (1.0 + exp_tau)
                batmlay = boftlay * torch.exp(-tauprof[..., 1:]) * (1.0 - exp_tau)
                boftatm_tail = torch.cumsum(batmlay.flip(-1), dim=-1).flip(-1)
                boftatm = torch.cat([boftatm_tail, torch.zeros_like(boftatm_tail[..., :1])], dim=-1)

            tau0 = tauprof[..., 0]  # (..., nf)
            boftbg = emissivity_f * self._tk2b_mod(hvk, Ts) + (1 - emissivity_f) * boftotl
            mask = tau0 < expmax
            bakgrnd = torch.where(mask, boftbg * torch.exp(-tau0), torch.zeros_like(boftbg))
            boftotl = torch.where(mask, bakgrnd + boftatm[..., 0], boftatm[..., 0])
            denom = 1.0 - torch.exp(-tau0)
            boftmr = torch.where(mask, boftatm[..., 0] / denom, boftatm[..., 0])
        else:
            if nl > 1:
                taulay_tail = taulay_t[..., 1:]
                exp_tau = torch.exp(-taulay_tail)
                boftlay = (boft[..., :-1] + boft[..., 1:] * exp_tau) / (1.0 + exp_tau)
                # tauprof[0] = 0 and tauprof[i] = sum of taulay[1..i], matching
                # pyrtlib and the top-down branch (taulay[..., 0] is ignored).
                tauprof = torch.cat(
                    [torch.zeros_like(taulay_t[..., :1]), torch.cumsum(taulay_tail, dim=-1)], dim=-1
                )
                batmlay = boftlay * torch.exp(-tauprof[..., :-1]) * (1.0 - exp_tau)
                boftatm = torch.cat(
                    [torch.zeros_like(batmlay[..., :1]), torch.cumsum(batmlay, dim=-1)], dim=-1
                )

            tau_last = tauprof[..., nl - 1]  # (..., nf)
            boftbg = self._tk2b_mod(hvk, Tc)
            mask = tau_last < expmax
            bakgrnd = torch.where(mask, boftbg * torch.exp(-tau_last), torch.zeros_like(tau_last))
            boftatm_last = boftatm[..., nl - 1]
            boftotl = torch.where(mask, bakgrnd + boftatm_last, boftatm_last)
            denom = 1.0 - torch.exp(-tau_last)
            boftmr = torch.where(mask, boftatm_last / denom, boftatm_last)

        return boftotl, boftatm, boftmr, tauprof, hvk, boft, bakgrnd

    def _cloudy_absorption(
        self,
        t: torch.Tensor,
        denl: Optional[torch.Tensor],
        deni: Optional[torch.Tensor],
    ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Compute cloud liquid and ice absorption profiles.

        Parameters
        ----------
        t : torch.Tensor
            Temperature profile in K.
        denl : torch.Tensor, optional
            Cloud liquid water density in g/m^3.
        deni : torch.Tensor, optional
            Cloud ice density in g/m^3.

        Returns
        -------
        tuple[torch.Tensor or None, torch.Tensor or None]
            ``(aliq, aice)`` cloud liquid and ice absorption in Np/m, shape
            ``(..., nf, nl)``; ``None`` for an absent phase.

        Raises
        ------
        ValueError
            If absorption models are not initialized.

        Notes
        -----
        The formulation follows pyrtlib cloud absorption components, then
        converts from Np/km to Np/m. Negative values are warned and clamped to
        zero for numerical robustness.

        Provenance: physical model intent is taken from the docstring of
        ``pyrtlib.rt_equation.RTEquation.cloudy_absorption``.

        References
        ----------
        :cite:alp:`Westwater-1972`

        See Also
        --------
        pyrtlib.absorption_model.LiqAbsModel.liquid_water_absorption
        """
        if denl is None and deni is None:
            return None, None

        ref = next((arg for arg in (t, denl, deni) if isinstance(arg, torch.Tensor)), None)
        t_t = torch.as_tensor(
            t,
            device=ref.device if ref is not None else None,
            dtype=ref.dtype if ref is not None else None,
        )
        denl_t = torch.as_tensor(denl, device=t_t.device, dtype=t_t.dtype) if denl is not None else None
        deni_t = torch.as_tensor(deni, device=t_t.device, dtype=t_t.dtype) if deni is not None else None
        if denl_t is not None and deni_t is not None:
            t_t, denl_t, deni_t = torch.broadcast_tensors(t_t, denl_t, deni_t)
        elif denl_t is not None:
            t_t, denl_t = torch.broadcast_tensors(t_t, denl_t)
        elif deni_t is not None:
            t_t, deni_t = torch.broadcast_tensors(t_t, deni_t)

        if self.liq is None:
            raise ValueError("RTModel absorption models not initialised.")

        batch_shape = t_t.shape[:-1]
        frq_t = self.freqs.to(device=t_t.device, dtype=t_t.dtype)
        nf = int(frq_t.numel())
        frq_grid = frq_t.reshape((1,) * len(batch_shape) + (nf, 1))  # (..., nf, 1)

        c = torch.tensor(consts.light * 100, dtype=t_t.dtype, device=t_t.device)  # cm/s
        ghz2hz = torch.as_tensor(1e9, dtype=t_t.dtype, device=t_t.device)
        db2np = torch.log(torch.tensor(10.0, dtype=t_t.dtype, device=t_t.device)) * 0.1
        t_exp = t_t.unsqueeze(-2)  # (..., 1, nl)
        wave = c / (frq_grid * ghz2hz)  # wavelength in cm, (..., nf, 1)

        aliq = aice = None
        if denl_t is not None:
            denl_exp = denl_t.unsqueeze(-2)  # (..., 1, nl)
            liq_abs = self.liq.liquid_water_absorption(denl_exp, frq_grid, t_exp)
            aliq = torch.where(denl_exp > 0, liq_abs, torch.zeros_like(liq_abs))

        if deni_t is not None:
            deni_exp = deni_t.unsqueeze(-2)  # (..., 1, nl)
            # Ice absorption in dB/km converted to Np/km.
            ice_abs = ((8.18645 / wave) * deni_exp) * 0.000959553 * db2np
            aice = torch.where(deni_exp > 0, ice_abs, torch.zeros_like(ice_abs))

        # Clamp any negative absorption early and signal it clearly.
        has_neg_liq = aliq is not None and bool(torch.any(aliq < 0))
        has_neg_ice = aice is not None and bool(torch.any(aice < 0))
        if has_neg_liq or has_neg_ice:
            liq_min = float(aliq.min().detach().cpu()) if has_neg_liq else 0.0
            ice_min = float(aice.min().detach().cpu()) if has_neg_ice else 0.0
            warnings.warn(
                f"Negative cloud absorption detected (liq_min={liq_min:.3e}, ice_min={ice_min:.3e}); clamping to zero.",
            )
            if aliq is not None:
                aliq = aliq.clamp_min(0.0)
            if aice is not None:
                aice = aice.clamp_min(0.0)

        # Np/km -> Np/m.
        return (
            aliq / 1000.0 if aliq is not None else None,
            aice / 1000.0 if aice is not None else None,
        )

    def _clearsky_absorption(
        self,
        p: torch.Tensor,
        t: torch.Tensor,
        e: torch.Tensor,
        o3n: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute clear-sky water-vapor and dry-air absorption profiles.

        Parameters
        ----------
        p : torch.Tensor
            Pressure profile in mbar.
        t : torch.Tensor
            Temperature profile in K.
        e : torch.Tensor
            Water-vapor partial pressure profile in mbar.
        o3n : torch.Tensor, optional
            Ozone number density profile in molecules/m^3.

        Returns
        -------
        tuple[torch.Tensor, torch.Tensor]
            ``(awet, adry)`` in Np/m, shape ``(..., nf, nl)``.

        Raises
        ------
        ValueError
            If required absorption models are not initialized.

        Notes
        -----
        ``awet`` is computed from water-vapor line and continuum terms.
        ``adry`` combines :math:`O_2`, :math:`N_2`, and optional :math:`O_3`
        (only when ``o3n`` is a tensor and an ozone model is configured).
        Internal ppm outputs from absorption-model routines are converted to
        Np/m in this method.

        Provenance: base clear-sky formulation statement is taken from the
        docstring of ``pyrtlib.rt_equation.RTEquation.clearsky_absorption``.

        References
        ----------
        :cite:alp:`Liebe-Layton`
        :cite:alp:`Rosenkranz-1988`

        See Also
        --------
        pyrtlib.absorption_model.H2OAbsModel.h2o_absorption
        pyrtlib.absorption_model.O2AbsModel.o2_absorption
        """
        h2o_model = self.h2o
        o2_model = self.o2
        if h2o_model is None or o2_model is None:
            raise ValueError("RTModel absorption models not initialised.")
        n2_model = self.n2
        o3_model = self.o3

        ref = next((arg for arg in (p, t, e) if isinstance(arg, torch.Tensor)), None)
        p_t = torch.as_tensor(
            p,
            device=ref.device if ref is not None else None,
            dtype=ref.dtype if ref is not None else None,
        )
        t_t = torch.as_tensor(t, device=p_t.device, dtype=p_t.dtype)
        e_t = torch.as_tensor(e, device=p_t.device, dtype=p_t.dtype)
        p_t, t_t, e_t = torch.broadcast_tensors(p_t, t_t, e_t)

        batch_shape = p_t.shape[:-1]
        frq_t = self.freqs.to(device=p_t.device, dtype=p_t.dtype)
        nf = int(frq_t.numel())
        frq_grid = frq_t.reshape((1,) * len(batch_shape) + (nf, 1))  # (..., nf, 1)

        factor = 0.182 * frq_grid  # ppm -> dB/km
        db2np = torch.log(torch.tensor(10.0, dtype=p_t.dtype, device=p_t.device)) * 0.1

        ekpa_lvl = e_t / 10.0
        pdrykpa_lvl = p_t / 10.0 - ekpa_lvl
        pdrykpa = pdrykpa_lvl.unsqueeze(-2)  # (..., 1, nl)
        vx = (300.0 / t_t).unsqueeze(-2)  # (..., 1, nl)
        ekpa = ekpa_lvl.unsqueeze(-2)  # (..., 1, nl)

        npp, ncpp = h2o_model.h2o_absorption(pdrykpa, vx, ekpa, frq_grid)
        awet = factor * (npp + ncpp) * db2np

        npp, ncpp = o2_model.o2_absorption(pdrykpa, vx, ekpa, frq_grid)
        aO2 = factor * (npp + ncpp) * db2np

        if n2_model:
            aN2 = n2_model.n2_absorption(
                t_t.unsqueeze(-2), (pdrykpa_lvl * 10.0).unsqueeze(-2), frq_grid
            )
        else:
            aN2 = torch.zeros_like(aO2)

        if isinstance(o3n, torch.Tensor) and o3_model and hasattr(o3_model, "o3_absorption"):
            o3n_t = torch.as_tensor(o3n, device=p_t.device, dtype=p_t.dtype)
            o3n_t = torch.broadcast_to(o3n_t, p_t.shape).unsqueeze(-2)
            aO3 = o3_model.o3_absorption(t_t.unsqueeze(-2), p_t.unsqueeze(-2), frq_grid, o3n_t)
        else:
            aO3 = torch.zeros_like(aO2)

        adry = aO2 + aN2 + aO3
        # Np/km -> Np/m.
        return awet / 1000.0, adry / 1000.0