# Source code for torchMWRT.profile

from __future__ import annotations

"""Atmospheric profile container and input-normalization utilities."""

import numpy as np
import torch
import xarray as xr

from .atm import vapor_pressure_from_relative_humidity, vapor_pressure_from_specific_humidity

# Input types accepted for every profile variable; AtmProfile converts them to torch tensors.
ProfileInput = torch.Tensor | np.ndarray | xr.DataArray


class AtmProfile(object):
    """Atmospheric profile container for torchMWRT radiative transfer.

    Parameters
    ----------
    temperature : torch.Tensor or numpy.ndarray or xarray.DataArray
        Atmospheric temperature profile in K.
    height : torch.Tensor or numpy.ndarray or xarray.DataArray
        Geometric height profile in m above mean sea level.
    pressure : torch.Tensor or numpy.ndarray or xarray.DataArray
        Atmospheric pressure profile in hPa.
    rh : torch.Tensor or numpy.ndarray or xarray.DataArray, optional
        Relative humidity as fraction (0-1).
    q : torch.Tensor or numpy.ndarray or xarray.DataArray, optional
        Specific humidity in kg/kg.
    e : torch.Tensor or numpy.ndarray or xarray.DataArray, optional
        Vapor pressure in hPa.
    lwc : torch.Tensor or numpy.ndarray or xarray.DataArray, optional
        Liquid water content in g/m^3.
    iwc : torch.Tensor or numpy.ndarray or xarray.DataArray, optional
        Ice water content in g/m^3.
    emissivity : torch.Tensor or numpy.ndarray or xarray.DataArray, optional
        Surface emissivity, scalar or frequency-dependent array.

    Raises
    ------
    ValueError
        If humidity inputs are not provided exactly once, or if a profile
        variable's last dimension does not match the height profile.
    TypeError
        If an input is not a torch tensor, NumPy array, or xarray DataArray.
    SystemExit
        If height is not strictly monotonic.

    Notes
    -----
    - Expected units: ``height`` in m above mean sea level, ``pressure`` in hPa,
      ``temperature`` in K, and cloud water contents in g/m^3.
    - Exactly one humidity input must be provided: relative humidity ``rh``,
      specific humidity ``q``, or vapor pressure ``e``.
    - All profile inputs are converted to torch tensors on a common dtype/device.
    - If vertical levels are descending, the working profile (``z``, ``p``, ``tk``,
      ``e``, ``denliq``, ``denice``) is flipped along the last (level) axis so
      internal computations use ascending height. The attributes ``temperature``,
      ``height``, ``pressure``, ``lwc`` and ``iwc`` keep the orientation of the input.
    - After construction ``z`` is expressed relative to its lowest level, whose
      original value is stored in ``z0``.

    Provenance: profile-orientation handling and antenna-relative height
    normalization are adapted from the profile setup logic in
    ``pyrtlib.tb_spectrum.TbCloudRTE``.
    """

    def __init__(
        self,
        *,
        temperature: ProfileInput,
        height: ProfileInput,
        pressure: ProfileInput,
        rh: ProfileInput | None = None,
        q: ProfileInput | None = None,
        e: ProfileInput | None = None,
        lwc: ProfileInput | None = None,
        iwc: ProfileInput | None = None,
        emissivity: ProfileInput | None = None,
    ):
        if sum(value is not None for value in (rh, q, e)) != 1:
            raise ValueError("Provide exactly one of rh, q, or e.")

        dtype, device = self._resolve_dtype_device(temperature, height, pressure, rh, q, e, lwc, iwc, emissivity)

        def convert(value, name):
            return self._as_profile_tensor(value, name=name, device=device, dtype=dtype)

        def convert_optional(value, name):
            return None if value is None else convert(value, name)

        # Mandatory variables are converted unconditionally so that a missing
        # (None) value is reported as a TypeError naming the variable.
        self.temperature = convert(temperature, "temperature")
        self.height = convert(height, "height")
        self.pressure = convert(pressure, "pressure")
        rh_data = convert_optional(rh, "rh")
        q_data = convert_optional(q, "q")
        e_data = convert_optional(e, "e")
        self.lwc = convert_optional(lwc, "lwc")
        self.iwc = convert_optional(iwc, "iwc")
        self.emissivity_arr = convert_optional(emissivity, "emissivity")

        # Keep batch coordinates for output mapping/dataset. If temperature carries
        # labeled leading coordinates (e.g. "time"), preserve them.
        self.batch_coords, self._level_dim = self._infer_batch_coords(temperature)

        # Working copies used by the radiative-transfer code; these (not the
        # input-oriented attributes above) are re-oriented to ascending height.
        self.z = self.height
        self.p = self.pressure
        self.denice = self.iwc
        self.denliq = self.lwc
        self.tk = self.temperature
        self.o3n = None

        self._validate_vertical_shapes(rh_data, q_data, e_data)
        rh_data, q_data, e_data = self._ensure_ascending(rh_data, q_data, e_data)

        # Vapor pressure is computed in Pa and stored in hPa.
        if rh_data is not None:
            e_pa = vapor_pressure_from_relative_humidity(self.tk, rh_data)
        elif q_data is not None:
            p_pa = self.p * 100.0
            e_pa = vapor_pressure_from_specific_humidity(q_data, p_pa)
        else:  # e_data is not None
            e_pa = e_data * 100.0
        self.e = e_pa / 100.0  # hPa

        self.nl = int(self.z.shape[-1])
        self.batch_shape = tuple(self.tk.shape[:-1]) if self.tk.dim() > 1 else ()
        self.ice = False

        # Express heights relative to the lowest level. Indexing the level
        # (last) axis keeps this correct when height carries batch dimensions;
        # for a 1-D height it is identical to ``z[0]``.
        self.z0 = self.z[..., 0]
        self.z = self.z - self.z0.unsqueeze(-1)

    @staticmethod
    def _flip_levels(tensor: torch.Tensor | None) -> torch.Tensor | None:
        """Reverse a profile along its last (level) axis; pass ``None`` through."""
        return None if tensor is None else torch.flip(tensor, dims=[-1])

    def _ensure_ascending(
        self,
        rh: torch.Tensor | None,
        q: torch.Tensor | None,
        e: torch.Tensor | None,
    ) -> tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
        """Re-orient the working profile so height increases along the level axis.

        Parameters
        ----------
        rh, q, e : torch.Tensor, optional
            Humidity profiles that must follow the same orientation as ``self.z``.

        Returns
        -------
        tuple
            ``(rh, q, e)``, flipped along the last axis if the profile was descending.

        Raises
        ------
        SystemExit
            If height is neither strictly increasing nor strictly decreasing.
            The exception type is kept for backward compatibility with callers.
        """
        dz = torch.diff(self.z, dim=-1)
        if torch.all(dz > 0):
            return rh, q, e
        if not torch.all(dz < 0):
            raise SystemExit(
                "ERROR: input profile seems incorrect. " "It must be monotonically increasing or decreasing"
            )

        # Every variable, including height itself, is flipped along the level
        # axis, which is the axis the monotonicity test above was made on.
        flip = self._flip_levels
        self.z = flip(self.z)
        self.p = flip(self.p)
        self.tk = flip(self.tk)
        self.denliq = flip(self.denliq)
        self.denice = flip(self.denice)
        return flip(rh), flip(q), flip(e)

    def _resolve_dtype_device(
        self,
        *values: ProfileInput | None,
    ) -> tuple[torch.dtype, torch.device]:
        """Determine default dtype/device from provided inputs.

        Parameters
        ----------
        *values : ProfileInput or None
            Candidate profile inputs.

        Returns
        -------
        tuple[torch.dtype, torch.device]
            Dtype/device inferred from the first torch tensor input, or
            ``(torch.float64, cpu)`` if none are torch tensors. If that first
            tensor is not floating point (e.g. integer heights), the dtype falls
            back to ``torch.float64`` so the profiles are not truncated.
        """
        for value in values:
            if isinstance(value, torch.Tensor):
                dtype = value.dtype if value.dtype.is_floating_point else torch.float64
                return dtype, value.device
        return torch.float64, torch.device("cpu")

    @staticmethod
    def _tensor_from_numpy(array: np.ndarray) -> torch.Tensor:
        """Wrap a NumPy array as a tensor, copying only when torch requires it.

        ``torch.from_numpy`` rejects arrays with non-native byte order or negative
        strides and warns on read-only arrays, so those are normalised first.
        Writable, native arrays are wrapped without copying.
        """
        if not array.dtype.isnative:
            array = array.astype(array.dtype.newbyteorder("="))
        if not array.flags.writeable or any(stride < 0 for stride in array.strides):
            array = array.copy()
        return torch.from_numpy(array)

    def _as_profile_tensor(
        self,
        value: ProfileInput,
        *,
        name: str,
        device: torch.device,
        dtype: torch.dtype,
    ) -> torch.Tensor:
        """Convert a profile input to a torch tensor on target dtype/device.

        Parameters
        ----------
        value : torch.Tensor or numpy.ndarray or xarray.DataArray
            Input profile data.
        name : str
            Variable name used in error messages.
        device : torch.device
            Target device.
        dtype : torch.dtype
            Target dtype.

        Returns
        -------
        torch.Tensor
            Converted tensor.

        Raises
        ------
        TypeError
            If input type is not supported.
        """
        if isinstance(value, torch.Tensor):
            tensor = value
        elif isinstance(value, np.ndarray):
            tensor = self._tensor_from_numpy(value)
        elif isinstance(value, xr.DataArray):
            tensor = self._tensor_from_numpy(np.asarray(value.values))
        else:
            raise TypeError(
                f"{name} must be a torch.Tensor, numpy.ndarray, or xarray.DataArray, got {type(value).__name__}."
            )
        return tensor.to(device=device, dtype=dtype)

    def _infer_batch_coords(self, ref_var: ProfileInput) -> tuple[dict[str, np.ndarray], str]:
        """Infer output batch coordinates from the reference input variable.

        Parameters
        ----------
        ref_var : torch.Tensor or numpy.ndarray or xarray.DataArray
            Reference variable used to define leading dimensions.

        Returns
        -------
        tuple[dict[str, numpy.ndarray], str]
            Batch-coordinate mapping and level-dimension name.

        Notes
        -----
        The last dimension is interpreted as vertical level; all leading
        dimensions are propagated as batch dimensions. Labeled inputs keep their
        dimension names and coordinate values; unlabeled inputs get
        ``batch_<index>`` names with integer ranges and the level name ``"level"``.
        """
        if isinstance(ref_var, xr.DataArray):
            level_dim = str(ref_var.dims[-1]) if ref_var.ndim > 0 else "level"
            if ref_var.ndim <= 1:
                return {}, level_dim
            coords: dict[str, np.ndarray] = {}
            for dim_name, size in zip(ref_var.dims[:-1], ref_var.shape[:-1]):
                # Use the stored coordinate only when it matches the dimension
                # length; otherwise fall back to a positional index.
                if dim_name in ref_var.coords and ref_var.coords[dim_name].size == size:
                    coords[str(dim_name)] = np.asarray(ref_var.coords[dim_name].values)
                else:
                    coords[str(dim_name)] = np.arange(size)
            return coords, level_dim

        leading_shape = tuple(ref_var.shape)[:-1]
        return {f"batch_{index}": np.arange(size) for index, size in enumerate(leading_shape)}, "level"

    def _load_emissivity(self, nf: int) -> torch.Tensor:
        """Return emissivity as a tensor ready for model execution.

        Parameters
        ----------
        nf : int
            Number of frequencies in the radiative-transfer simulation.

        Returns
        -------
        torch.Tensor
            Emissivity tensor on the dtype/device of ``self.tk``. Defaults to
            ones of shape ``(nf,)`` when emissivity input is not provided.
        """
        if self.emissivity_arr is None:
            return torch.ones((nf,), device=self.tk.device, dtype=self.tk.dtype)
        return self.emissivity_arr.to(device=self.tk.device, dtype=self.tk.dtype)

    def _validate_vertical_shapes(
        self,
        rh: torch.Tensor | None,
        q: torch.Tensor | None,
        e: torch.Tensor | None,
    ) -> None:
        """Validate that all profile variables share the same vertical length.

        Parameters
        ----------
        rh : torch.Tensor, optional
            Relative-humidity profile.
        q : torch.Tensor, optional
            Specific-humidity profile.
        e : torch.Tensor, optional
            Vapor-pressure profile.

        Raises
        ------
        ValueError
            If any variable's last dimension does not match the height profile.
        """
        nl = self.z.shape[-1]
        candidates = {
            "temperature": self.tk,
            "pressure": self.p,
            "liquid water content": self.denliq,
            "ice water content": self.denice,
            "relative humidity": rh,
            "specific humidity": q,
            "vapor pressure": e,
        }
        for name, tensor in candidates.items():
            if tensor is None:
                continue
            if tensor.shape[-1] != nl:
                raise ValueError(
                    f"{name} last dimension must match height profile (expected {nl}, got {tensor.shape[-1]})."
                )