Source code for gdptools.data.user_data

"""User data classes for standardizing input data across different sources.

This module provides a comprehensive set of classes for handling various types of
geospatial data sources in the gdptools package. These classes serve as data
containers that standardize inputs from different sources (STAC catalogs, local
files, web services) and ensure all necessary metadata is available for processing.

The module implements an abstract base class pattern with specialized subclasses
for different data source types, providing a consistent interface for data
preparation, spatial subsetting, and aggregation operations.

Classes:
    UserData: Abstract base class defining the common interface for all data sources.
    ClimRCatData: Interface for Climate-R catalog datasets with automatic metadata handling.
    UserCatData: Handler for user-provided xarray datasets with custom configuration.
    NHGFStacData: Interface for NHGF STAC catalog datasets with spatiotemporal filtering.
    UserTiffData: Specialized handler for GeoTIFF and raster data processing.

Examples:
    Using Climate-R catalog data:

    .. code-block:: python

        from gdptools.data.user_data import ClimRCatData
        import pandas as pd

        # Load catalog and create data source
        cat_url = "https://github.com/mikejohnson51/climateR-catalogs/releases/download/June-2024/catalog.parquet"
        cat = pd.read_parquet(cat_url)
        cat_dict = {"temp": cat.query("id == 'gridmet' & variable == 'tmmn'").to_dict("records")[0]}

        data = ClimRCatData(
            source_cat_dict=cat_dict,
            target_gdf="watersheds.shp",
            target_id="huc12",
            source_time_period=["2020-01-01", "2020-12-31"]
        )

    Using custom xarray dataset:

    .. code-block:: python

        from gdptools.data.user_data import UserCatData
        import xarray as xr

        # Load custom dataset
        ds = xr.open_dataset("climate_data.nc")

        data = UserCatData(
            source_ds=ds,
            source_crs=4326,
            source_x_coord="longitude",
            source_y_coord="latitude",
            source_t_coord="time",
            source_var=["temperature", "precipitation"],
            target_gdf="regions.shp",
            target_crs=4326,
            target_id="region_id",
            source_time_period=["2020-01-01", "2020-12-31"]
        )

Notes:
    All classes automatically handle coordinate reference system validation,
    spatial intersection checking, and temporal subsetting to ensure data
    compatibility for downstream processing operations.

"""

import datetime
import logging
import time
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Union

if TYPE_CHECKING:
    import pystac

import geopandas as gpd
import numpy as np
import pandas as pd
import xarray as xr
from pyproj import CRS

from gdptools.data.agg_gen_data import AggData
from gdptools.data.odap_cat_data import CatClimRItem
from gdptools.data.weight_gen_data import WeightData
from gdptools.depreciation_utils import deprecate_kwargs
from gdptools.helpers import build_subset, build_subset_tiff
from gdptools.utils import (
    _assert_period_intersects_time_axis,
    _check_for_intersection,
    _check_for_intersection_nc,
    # _get_cells_poly,
    _get_data_via_catalog,
    _get_rxr_dataset,
    _get_shp_bounds_w_buffer,
    _get_shp_file,
    _get_top_to_bottom,
    _get_xr_dataset,
    _is_valid_crs,
    _process_period,
    _read_shp_file,
    _sel_with_time_diagnostic,
)
from gdptools.utils_optimized import _get_cells_poly_fast

logger = logging.getLogger(__name__)


def _rotate_longitude_if_needed(
    ds: xr.Dataset,
    x_coord: str,
    *,
    should_rotate: bool,
) -> xr.Dataset:
    """Return ``ds`` with longitude rotated into -180..180 and sorted.

    When ``should_rotate`` is True, a new Dataset is returned with its
    longitude coordinate wrapped into -180..180 and sorted ascending.
    Otherwise ``ds`` is returned unchanged.

    Never mutates the input Dataset. Both ``assign_coords`` and
    ``sortby`` return new Dataset objects, so the caller's reference is
    left untouched.

    Args:
        ds: Source Dataset that may have 0..360 longitude.
        x_coord: Name of the longitude coordinate on ``ds``.
        should_rotate: True iff rotation is required. Computed by the
            caller from ``_check_for_intersection_nc`` flags.

    Returns:
        A new Dataset in -180..180 sorted form if rotation was needed;
        otherwise ``ds`` unchanged.
    """
    if not should_rotate:
        return ds
    logger.info("  - rotating into -180 - 180")
    return ds.assign_coords(
        {x_coord: (ds.coords[x_coord] + 180) % 360 - 180}
    ).sortby(x_coord)


[docs] class UserData(ABC): """Abstract base class for standardizing geospatial data inputs. This class defines the common interface that all data source classes must implement. It ensures consistent data preparation, spatial subsetting, and aggregation capabilities across different data sources (catalogs, files, web services). The class enforces a standardized workflow for data handling: 1. Data loading and validation 2. Coordinate reference system handling 3. Spatial and temporal subsetting 4. Data preparation for weight generation and aggregation All subclasses must implement the abstract methods to provide source-specific functionality while maintaining interface consistency. Notes: This is an abstract base class and cannot be instantiated directly. Use one of the concrete subclasses (ClimRCatData, UserCatData, NHGFStacData, UserTiffData) instead. """
[docs] @abstractmethod def __init__(self) -> None: """Initialize the data source. This method must be implemented by subclasses to handle source-specific initialization requirements. """ pass
[docs] @abstractmethod def get_target_crs(self) -> CRS: """Get the coordinate reference system of the target geometries. Returns: The coordinate reference system of the target vector data. """ pass
[docs] @abstractmethod def get_source_subset(self, key: str) -> xr.DataArray: """Get a spatially and temporally subset of the source data. Args: key (str): Variable name or identifier for the data to subset. Returns: A spatially and temporally subset DataArray for the specified variable. """ pass
[docs] @abstractmethod def prep_wght_data(self) -> WeightData: """Prepare data for weight generation calculations. Returns: A WeightData instance containing the necessary data for calculating spatial intersection weights between source and target geometries. """ pass
[docs] @abstractmethod def prep_interp_data(self, key: str, poly_id: Union[str, int]) -> AggData: """Prepare data for interpolation operations. Args: key (str): Variable name or identifier for the data to prepare. poly_id (Union[str, int]): Identifier for the target polygon geometry. Returns: An AggData instance configured for interpolation operations. """ pass
[docs] @abstractmethod def prep_agg_data(self, key: str) -> AggData: """Prepare data for aggregation operations. Args: key (str): Variable name or identifier for the data to prepare. Returns: An AggData instance configured for aggregation operations. """ pass
[docs] @abstractmethod def get_vars(self) -> list[str]: """Get the list of available variables in the data source. Returns: list[str]: List of variable names available for processing. """ pass
[docs] @abstractmethod def get_feature_id(self) -> str: """Get the identifier column name for target geometries. Returns: str: The column name used as the unique identifier for target geometries. """ pass
[docs] @abstractmethod def get_class_type(self) -> str: """Get the type identifier for this data source class. Returns: str: A string identifier for the data source type (e.g., "ClimRCatData"). """ pass
[docs] class ClimRCatData(UserData): """Interface for Climate-R catalog datasets with automatic metadata handling. This class provides seamless integration with the Climate-R catalog system, developed by Mike Johnson and available at https://github.com/mikejohnson51/climateR-catalogs. It automatically handles metadata extraction, coordinate system detection, and spatial/temporal subsetting for catalog-based datasets. The class accepts Climate-R catalog dictionaries and automatically configures data access parameters, coordinate names, and temporal ranges based on the catalog metadata. Attributes: source_cat_dict: Dictionary mapping variable names to Climate-R catalog metadata. target_gdf: GeoDataFrame containing target geometries for spatial operations. target_id: Column name for unique identifiers in target geometries. source_time_period: Processed time period for temporal subsetting. Examples: Basic usage with Climate-R catalog: .. code-block:: python import pandas as pd from gdptools.data.user_data import ClimRCatData # Load Climate-R catalog cat_url = "https://github.com/mikejohnson51/climateR-catalogs/releases/download/June-2024/catalog.parquet" cat = pd.read_parquet(cat_url) # Create catalog dictionary for TerraClimate variables cat_vars = ["aet", "pet", "PDSI"] cat_params = [ cat.query("id == 'terraclim' & variable == @var").to_dict("records")[0] for var in cat_vars ] source_cat_dict = dict(zip(cat_vars, cat_params)) # Initialize data source data = ClimRCatData( source_cat_dict=source_cat_dict, target_gdf="watersheds.shp", target_id="huc12", source_time_period=["2020-01-01", "2020-12-31"] ) # Access data variables = data.get_vars() subset = data.get_source_subset("aet") Multiple variables from GridMET: .. code-block:: python # GridMET temperature and precipitation gm_vars = ["tmmn", "tmmx", "pr"] gm_params = [ cat.query("id == 'gridmet' & variable == @var").to_dict("records")[0] for var in gm_vars ] source_cat_dict = dict(zip(gm_vars, gm_params)) data = ClimRCatData( source_cat_dict=source_cat_dict, target_gdf=target_polygons, target_id="poly_id", source_time_period=["2019-01-01", "2021-12-31"] ) """ _deprecation_map = { # noqa: RUF012 "cat_dict": "source_cat_dict", "f_feature": "target_gdf", "id_feature": "target_id", "period": "source_time_period", }
[docs] @deprecate_kwargs(_deprecation_map, removed_in="1.0.0") def __init__( # noqa: D417 self: "ClimRCatData", *, source_cat_dict: dict[str, dict[str, Any]], target_gdf: Union[str, Path, gpd.GeoDataFrame], target_id: str, source_time_period: list[Union[str, pd.Timestamp, datetime.datetime] | None], **kwargs, # noqa: ANN003 ) -> None: # sourcery skip: simplify-len-comparison """Initialize ClimRCatData for Climate-R catalog integration. Sets up data access for Climate-R catalog datasets with automatic metadata handling, coordinate system detection, and spatial/temporal validation. Args: source_cat_dict: Dictionary mapping variable names to Climate-R catalog metadata dictionaries. Each entry should contain the complete catalog record for a variable, including URL, coordinate names, CRS, and temporal information. target_gdf: GeoDataFrame containing target geometries, or path to a shapefile/GeoPackage that can be read by geopandas. target_id: Column name in target_gdf to use as unique identifier for geometries in weight calculations and aggregations. source_time_period: Two-element list defining the temporal range for data subsetting. Format: ["YYYY-MM-DD", "YYYY-MM-DD"] or ["YYYY-MM-DD HH:MM:SS", "YYYY-MM-DD HH:MM:SS"]. Selection is label-based and inclusive on both ends (xarray ``.sel(slice(...))`` semantics). For coarse time resolutions (monthly, 8-day composites) the period must include at least one label; a window that falls between labels produces an empty slice and raises ``ValueError``. For sub-daily sources specified with date-only strings, xarray treats the end date as ``T00:00:00`` — the final day's afternoon is excluded. Raises: KeyError: If target_id is not found in target_gdf columns. ValueError: If source_cat_dict is empty or contains invalid entries. TypeError: If catalog entries are missing required metadata fields. Notes: The catalog dictionary structure should match the Climate-R catalog format with fields like 'URL', 'X_name', 'Y_name', 'T_name', 'crs', 'toptobottom', etc. Invalid or incomplete entries will raise errors during initialization. Examples: Initialize with TerraClimate data: .. code-block:: python import pandas as pd cat_url = ( "https://github.com/mikejohnson51/climateR-catalogs/releases/download/June-2024/catalog.parquet" ) cat = pd.read_parquet(cat_url) # Create catalog entry for actual evapotranspiration aet_params = ( cat.query("id == 'terraclim' & variable == 'aet'").to_dict("records")[0] ) source_cat_dict = {"aet": aet_params} data = ClimRCatData( source_cat_dict=source_cat_dict, target_gdf="watersheds.shp", target_id="huc12", source_time_period=["2020-01-01", "2020-12-31"], ) """ logger.info("Initializing ClimRCatData") logger.info(" - loading data") self.source_cat_dict = source_cat_dict self.target_gdf = target_gdf self.target_id = target_id self.source_time_period = _process_period(source_time_period) self._check_input_dict() self._gdf = _read_shp_file(self.target_gdf) if self.target_id not in self._gdf.columns: # print( # f"target_id {self.target_id} not in target_gdf columns: " # f" {self._gdf.columns}" # ) raise KeyError(f"target_id {self.target_id} not in target_gdf columns: {self._gdf.columns}") self.target_crs = self._gdf.crs self._check_id_feature() self._keys = self.get_vars() cat_cr = self._create_climrcats(key=self._keys[0]) logger.info(" - checking latitude bounds") is_intersect, is_degrees, is_0_360 = _check_for_intersection( cat_cr=cat_cr, gdf=self._gdf ) # Project the gdf to the crs of the gridded data self._gdf, self._gdf_bounds = _get_shp_file(shp_file=self._gdf, cat_cr=cat_cr, is_degrees=is_degrees) self._rotate_ds = bool((not is_intersect) & is_degrees & (is_0_360))
[docs] def get_target_crs(self) -> CRS: """Get the coordinate reference system of the target geometries. Returns: CRS: The coordinate reference system of the target vector data. """ return self.target_crs
[docs] def get_source_subset(self, key: str) -> xr.DataArray: """Get a spatially and temporally subset of the source data. This method uses the metadata from the Climate-R catalog to retrieve and subset the data for a specific variable. Args: key (str): The variable name to subset from the catalog. Returns: xr.DataArray: A subsetted xarray DataArray. """ cat_cr = self._create_climrcats(key=key) return _get_data_via_catalog( cat_cr=cat_cr, bounds=self._gdf_bounds, begin_date=self.source_time_period[0], end_date=self.source_time_period[1], rotate_lon=self._rotate_ds, )
[docs] def prep_interp_data(self, key: str, poly_id: Union[str, int]) -> AggData: """Prep AggData from ClimRCatData. Args: key (str): Name of the xarray gridded data variable. poly_id (Union[str, int]): ID number of the geodataframe geometry to clip the gridded data to. Returns: AggData: An instance of the AggData class, ready for interpolation. """ cat_cr = self._create_climrcats(key=key) # Select a feature and make sure it remains a geodataframe target_gdf = self._gdf[self._gdf[self.target_id] == poly_id] # Clip grid by x, y and time ds_ss = _get_data_via_catalog( cat_cr=cat_cr, bounds=self._gdf_bounds, begin_date=self.source_time_period[0], end_date=self.source_time_period[1], rotate_lon=self._rotate_ds, ) return AggData( variable=key, cat_cr=cat_cr, da=ds_ss, target_gdf=target_gdf, target_id=self.target_id, source_time_period=self.source_time_period, )
[docs] def prep_agg_data(self, key: str) -> AggData: """Prepare ClimRCatData data for aggregation methods. Args: key (str): The variable name to prepare for aggregation. Returns: AggData: An AggData instance ready for aggregation. """ cat_cr = self._create_climrcats(key=key) target_gdf = self._gdf ds_ss = _get_data_via_catalog( cat_cr=cat_cr, bounds=self._gdf_bounds, begin_date=self.source_time_period[0], end_date=self.source_time_period[1], rotate_lon=self._rotate_ds, ) return AggData( variable=key, cat_cr=cat_cr, da=ds_ss, target_gdf=target_gdf, target_id=self.target_id, source_time_period=self.source_time_period, )
[docs] def prep_wght_data(self) -> WeightData: """Prepare data for weight generation calculations. Returns: WeightData: Data required for calculating spatial intersection weights between source and target geometries. """ cat_cr = self._create_climrcats(key=self._keys[0]) ds_ss = _get_data_via_catalog( cat_cr=cat_cr, bounds=self._gdf_bounds, begin_date=self.source_time_period[0], end_date=self.source_time_period[1], rotate_lon=self._rotate_ds, ) tsrt = time.perf_counter() gdf_grid = _get_cells_poly_fast( xr_a=ds_ss, x=cat_cr.X_name, y=cat_cr.Y_name, crs_in=cat_cr.crs, ) tend = time.perf_counter() logger.info("Grid cells generated in %0.4f seconds", tend - tsrt) return WeightData(target_gdf=self._gdf, target_id=self.target_id, grid_cells=gdf_grid)
[docs] def get_feature_id(self) -> str: """Return target_id.""" return self.target_id
[docs] def get_vars(self) -> list[str]: """Return list of source_cat_dict keys, proxy for varnames.""" return list(self.source_cat_dict.keys())
def _check_input_dict(self: "ClimRCatData") -> None: """Check input source_cat_dict for content.""" if len(self.source_cat_dict) < 1: raise ValueError("source_cat_dict should have at least 1 key,value pair") def _check_id_feature(self: "ClimRCatData") -> None: """Check if target_id is in the GeoDataFrame columns.""" if self.target_id not in self._gdf.columns[:]: raise ValueError(f"shp_poly_idx: {self.target_id} not in gdf columns: {self._gdf.columns}")
[docs] def get_class_type(self) -> str: """Get the type identifier for this data source class.""" return "ClimRCatData"
def _create_climrcats(self: "ClimRCatData", key: str) -> CatClimRItem: """Create a CatClimRItem instance from the catalog dictionary. Args: key (str): The variable key to look up in the source catalog dictionary. Returns: CatClimRItem: A dataclass instance containing the metadata for the variable. """ return CatClimRItem(**self.source_cat_dict[key])
[docs] class UserCatData(UserData): """Handler for user-provided xarray datasets with custom configuration. This class provides a flexible interface for working with user-supplied gridded datasets that are not available through catalogs. It handles xarray datasets from local files, URLs, or in-memory objects, with full control over coordinate names, variable selection, and coordinate reference systems. The class performs comprehensive validation of input parameters and data compatibility, ensuring that coordinate names exist, variables are available, and coordinate reference systems are valid. Attributes: source_ds: The source xarray Dataset containing gridded data. target_gdf: GeoDataFrame containing target geometries. target_id: Column name for unique identifiers in target geometries. source_time_period: Processed time period for temporal subsetting. Examples: Basic usage with local NetCDF file: .. code-block:: python import xarray as xr from gdptools.data.user_data import UserCatData # Load custom dataset ds = xr.open_dataset("climate_data.nc") data = UserCatData( source_ds=ds, source_crs=4326, source_x_coord="longitude", source_y_coord="latitude", source_t_coord="time", source_var=["temperature", "precipitation"], target_gdf="watersheds.shp", target_crs=4326, target_id="huc12", source_time_period=["2020-01-01", "2020-12-31"] ) Using URL data source: .. code-block:: python # Remote dataset data = UserCatData( source_ds="https://example.com/climate.nc", source_crs="EPSG:4326", source_x_coord="lon", source_y_coord="lat", source_t_coord="time", source_var="temperature", target_gdf=polygons_gdf, target_crs=3857, target_id="poly_id", source_time_period=["2019-01-01", "2021-12-31"] ) Multiple variables with different CRS: .. code-block:: python # Projected dataset data = UserCatData( source_ds="projected_data.nc", source_crs=3857, # Web Mercator source_x_coord="x", source_y_coord="y", source_t_coord="time", source_var=["var1", "var2", "var3"], target_gdf="regions.gpkg", target_crs=4326, target_id="region_id", source_time_period=["2020-06-01", "2020-08-31"] ) """ _deprecation_map = { # noqa: RUF012 "ds": "source_ds", "proj_ds": "source_crs", "x_coord": "source_x_coord", "y_coord": "source_y_coord", "t_coord": "source_t_coord", "var": "source_var", "f_feature": "target_gdf", "proj_feature": "target_crs", "id_feature": "target_id", "period": "source_time_period", }
[docs] @deprecate_kwargs(_deprecation_map, removed_in="1.0.0") def __init__( self: "UserCatData", *, source_ds: Union[str, xr.Dataset], source_crs: str | int | CRS, source_x_coord: str, source_y_coord: str, source_t_coord: str, source_var: Union[str, list[str]], target_gdf: Union[str, Path, gpd.GeoDataFrame], target_crs: str | int | CRS, target_id: str, source_time_period: list[Union[str, pd.Timestamp, datetime.datetime] | None], ) -> None: """Initialize UserCatData for custom gridded datasets. Sets up data access for user-provided xarray datasets with comprehensive validation of coordinates, variables, and coordinate reference systems. Args: source_ds: Source dataset as xarray Dataset, file path, or URL. Can be any data source readable by xarray.open_dataset(). source_crs: Coordinate reference system for the source dataset. Can be EPSG code, proj4 string, WKT, or pyproj CRS object. source_x_coord: Name of the x-coordinate dimension in source_ds. Must exist in dataset coordinates. source_y_coord: Name of the y-coordinate dimension in source_ds. Must exist in dataset coordinates. source_t_coord: Name of the time coordinate dimension in source_ds. Must exist in dataset coordinates. source_var: Variable name(s) to use for processing. Can be a single string or list of strings. All variables must exist in source_ds. target_gdf: Target geometries as GeoDataFrame or file path. Can be any format readable by geopandas.read_file(). target_crs: Coordinate reference system for target geometries. Can be EPSG code, proj4 string, WKT, or pyproj CRS object. target_id: Column name in target_gdf to use as unique identifier. Must exist in target_gdf columns. source_time_period: Two-element list defining temporal range. Format: ["YYYY-MM-DD", "YYYY-MM-DD"] or with time stamps. Selection is label-based and inclusive on both ends (xarray ``.sel(slice(...))`` semantics). For coarse time resolutions (monthly, 8-day composites) the period must include at least one label; a window that falls between labels produces an empty slice and raises ``ValueError``. For sub-daily sources specified with date-only strings, xarray treats the end date as ``T00:00:00`` — the final day's afternoon is excluded. Raises: KeyError: If target_id is not in target_gdf columns, or if coordinate names or variables are not found in the dataset. ValueError: If source_crs or target_crs are invalid CRS specifications. FileNotFoundError: If source_ds or target_gdf file paths don't exist. Note: This class performs extensive validation upon initialization, including checking for the existence of specified coordinates and variables, validating CRS definitions, and ensuring the target ID exists in the geometries. Examples: Initialize with local NetCDF file: .. code-block:: python import xarray as xr data = UserCatData( source_ds="temperature_data.nc", source_crs=4326, source_x_coord="longitude", source_y_coord="latitude", source_t_coord="time", source_var=["temperature"], target_gdf="watersheds.shp", target_crs=4326, target_id="huc12", source_time_period=["2020-01-01", "2020-12-31"] ) Initialize with in-memory dataset: .. code-block:: python ds = xr.open_dataset("climate.nc") data = UserCatData( source_ds=ds, source_crs="EPSG:4326", source_x_coord="lon", source_y_coord="lat", source_t_coord="time", source_var=["temp", "precip"], target_gdf=polygons_gdf, target_crs=3857, target_id="poly_id", source_time_period=["2019-01-01", "2021-12-31"] ) """ logger.info("Initializing UserCatData") logger.info(" - loading data") self.source_ds = _get_xr_dataset(ds=source_ds) self.target_id = target_id self.target_gdf = _read_shp_file(shp_file=target_gdf) # Validate target_id exists in target_gdf if self.target_id not in self.target_gdf.columns: logger.error("target_id %s not in target_gdf columns: %s", self.target_id, self.target_gdf.columns) raise KeyError( f"target_id '{self.target_id}' not found in target_gdf columns: {list(self.target_gdf.columns)}" ) # Validate source CRS if not _is_valid_crs(source_crs): raise ValueError( f"Invalid CRS specification: {source_crs!r}. " "Please provide a valid CRS (e.g., EPSG code, proj string, WKT, or pyproj.CRS object)." ) # Validate target CRS if not _is_valid_crs(target_crs): raise ValueError( f"Invalid target CRS specification: {target_crs!r}. " "Please provide a valid CRS (e.g., EPSG code, proj string, WKT, or pyproj.CRS object)." ) # Validate coordinate names exist in dataset if source_x_coord not in self.source_ds.coords: raise KeyError( f"X coordinate '{source_x_coord}' not found in dataset coordinates: {list(self.source_ds.coords.keys())}" # noqa: E501 ) if source_y_coord not in self.source_ds.coords: raise KeyError( f"Y coordinate '{source_y_coord}' not found in dataset coordinates: {list(self.source_ds.coords.keys())}" # noqa: E501 ) if source_t_coord not in self.source_ds.coords: raise KeyError( f"Time coordinate '{source_t_coord}' not found in dataset coordinates: {list(self.source_ds.coords.keys())}" # noqa: E501 ) # Validate and process source variables source_var_list = source_var if isinstance(source_var, list) else [source_var] # Check for empty variable list if not source_var_list: raise ValueError("source_var cannot be empty. Please provide at least one variable name.") # Validate each variable exists in dataset missing_vars = [] if missing_vars := [var for var in source_var_list if var not in self.source_ds.data_vars]: raise KeyError( f"Variables {missing_vars} not found in dataset data variables: {list(self.source_ds.data_vars.keys())}" ) self.source_var = source_var_list # Validate and process time period try: self.source_time_period = _process_period(source_time_period) except (ValueError, TypeError, pd.errors.ParserError) as e: raise ValueError( # noqa: B904 f"Invalid time period specification: {source_time_period}. " f"Please provide valid date strings (e.g., 'YYYY-MM-DD'). Error: {e}" ) # Set remaining attributes self.source_crs = source_crs self.source_x_coord = source_x_coord self.source_y_coord = source_y_coord self.source_t_coord = source_t_coord self.target_crs = target_crs self._gdf_bounds = _get_shp_bounds_w_buffer( gdf=self.target_gdf, ds=self.source_ds, crs=self.source_crs, lon=self.source_x_coord, lat=self.source_y_coord, ) logger.info(" - checking latitude bounds") is_intersect, is_degrees, is_0_360 = _check_for_intersection_nc( ds=self.source_ds, x_name=self.source_x_coord, y_name=self.source_y_coord, proj=self.source_crs, gdf=self.target_gdf, ) self.source_ds = _rotate_longitude_if_needed( self.source_ds, self.source_x_coord, should_rotate=bool((not is_intersect) & is_degrees & (is_0_360)), ) # Verify source_time_period actually selects at least one time label # before any expensive subset/aggregation work. _assert_period_intersects_time_axis( self.source_ds, self.source_t_coord, self.source_time_period ) # calculate toptobottom (order of latitude coords) self._ttb = _get_top_to_bottom(self.source_ds, self.source_y_coord) logger.info(" - getting gridded data subset") self._agg_subset_dict = build_subset( bounds=self._gdf_bounds, xname=self.source_x_coord, yname=self.source_y_coord, tname=self.source_t_coord, toptobottom=self._ttb, date_min=self.source_time_period[0], date_max=self.source_time_period[1], )
[docs] @classmethod def __repr__(cls) -> str: """Print class name.""" return f"Class is {cls.__name__}"
[docs] def get_target_crs(self) -> str | int | CRS: """Return the coordinate reference system (CRS) for the source data. This method provides the CRS used by the target geometries. Returns: str | int | CRS: The CRS associated with the target geometries. """ return self.target_crs
[docs] def get_source_subset(self, key: str) -> xr.DataArray: """Get a spatially and temporally subset of the source dataset. This method applies the pre-calculated spatial and temporal subset dictionary to the source dataset for the given variable. Args: key (str): Name of the xarray gridded data variable. Returns: xr.DataArray: A subsetted xarray DataArray of the original source gridded data. """ return self.source_ds[key].sel(**self._agg_subset_dict)
[docs] def get_feature_id(self) -> str: """Return target_id.""" return self.target_id
[docs] def get_vars(self) -> list[str]: """Return list of vars in data.""" return self.source_var
[docs] def get_class_type(self) -> str: """Get the type identifier for this data source class.""" return "UserCatData"
[docs] def prep_interp_data(self, key: str, poly_id: Union[str, int]) -> AggData: """Prep AggData from UserCatData. Args: key (str): Name of the xarray gridded data variable. poly_id (Union[str, int]): ID number of the geodataframe geometry to clip the gridded data to. Returns: AggData: An instance of the AggData class, ready for interpolation. """ # Open grid and clip to geodataframe and time window data_ss: xr.DataArray = _sel_with_time_diagnostic( self.source_ds[key], self._agg_subset_dict, t_coord=self.source_t_coord, period_start=self.source_time_period[0], ) # Select a feature and make sure it remains a geodataframe target_gdf = self.target_gdf[self.target_gdf[self.target_id] == poly_id] # Reproject the feature to grid crs and get a buffered bounding box bounds = _get_shp_bounds_w_buffer( gdf=target_gdf, ds=self.source_ds, crs=self.source_crs, lon=self.source_x_coord, lat=self.source_y_coord, ) # Clip grid to time window and line geometry bbox buffer ss_dict = build_subset( bounds=bounds, xname=self.source_x_coord, yname=self.source_y_coord, tname=self.source_t_coord, toptobottom=self._ttb, date_min=str(self.source_time_period[0]), date_max=str(self.source_time_period[1]), ) ds_ss = _sel_with_time_diagnostic( data_ss, ss_dict, t_coord=self.source_t_coord, period_start=self.source_time_period[0], ) cat_cr = self._create_climrcats(key=key, da=ds_ss) return AggData( variable=key, cat_cr=cat_cr, da=ds_ss, target_gdf=target_gdf, target_id=self.target_id, source_time_period=self.source_time_period, )
[docs] def prep_agg_data(self, key: str) -> AggData: """Prepare data for aggregation operations. This method subsets the source dataset based on the pre-calculated spatial and temporal bounds and prepares an AggData object. Args: key (str): The variable name to prepare for aggregation. Returns: An AggData instance ready for aggregation. """ logger.info("Agg Data preparation - beginning") data_ss: xr.DataArray = _sel_with_time_diagnostic( self.source_ds[key], self._agg_subset_dict, t_coord=self.source_t_coord, period_start=self.source_time_period[0], ) target_gdf = self.target_gdf cat_cr = self._create_climrcats(key=key, da=data_ss) logger.info(" - returning AggData") return AggData( variable=key, cat_cr=cat_cr, da=data_ss, target_gdf=target_gdf, target_id=self.target_id, source_time_period=self.source_time_period, )
def _create_climrcats(self: "UserCatData", key: str, da: xr.DataArray) -> CatClimRItem: """Create a CatClimRItem instance from the user-provided metadata. Args: key (str): The variable name. da (xr.DataArray): The DataArray for the variable, used to extract metadata. Returns: CatClimRItem: A dataclass instance containing the metadata. """ return CatClimRItem( # id=self.id, URL="", varname=key, long_name=str(self._get_ds_var_attrs(da, "long_name")), T_name=self.source_t_coord, units=str(self._get_ds_var_attrs(da, "units")), X_name=self.source_x_coord, Y_name=self.source_y_coord, proj=str(self.source_crs), resX=max(np.diff(da[self.source_x_coord].values)), resY=max(np.diff(da[self.source_y_coord].values)), toptobottom=str(_get_top_to_bottom(da, self.source_y_coord)), ) def _get_ds_var_attrs(self, da: xr.DataArray, attr: str) -> Any: # noqa: ANN401 """Return source_var attributes. Args: da (xr.DataArray): The DataArray to inspect. attr (str): The attribute name to retrieve. Returns: Any: The value of the attribute, or "None" if not found. """ try: return da.attrs.get(attr) except KeyError: return "None"
[docs] def prep_wght_data(self) -> WeightData: """Prepare data for weight generation calculations. This method subsets the source dataset and generates grid cell polygons required for calculating spatial intersection weights. Returns: WeightData: Data required for weight generation. """ logger.info("Weight Data preparation - beginning") data_ss_wght = _sel_with_time_diagnostic( self.source_ds, self._agg_subset_dict, t_coord=self.source_t_coord, period_start=self.source_time_period[0], ) logger.info(" - calculating grid-cell polygons") start = time.perf_counter() grid_poly = _get_cells_poly_fast( xr_a=data_ss_wght, x=self.source_x_coord, y=self.source_y_coord, crs_in=self.source_crs, ) end = time.perf_counter() logger.info(f"Generating grid-cell polygons finished in {round(end - start, 2)} second(s)") return WeightData(target_gdf=self.target_gdf, target_id=self.target_id, grid_cells=grid_poly)
class _NHGFStacBase(UserData): """Base class for NHGF STAC catalog data sources. Provides shared initialization and helper methods for all STAC-backed data source classes (Zarr and GeoTIFF). """ def __init__( self, *, source_collection, # noqa: ANN001 source_var: Union[str, list[str]], target_gdf: Union[str, Path, gpd.GeoDataFrame], target_id: str, ) -> None: """Initialize shared STAC data source attributes. Args: source_collection: STAC collection object for the dataset. source_var (Union[str, list[str]]): Variable name(s) to process. target_gdf (Union[str, Path, gpd.GeoDataFrame]): Target geometries. target_id (str): Column name in target_gdf containing unique identifiers. Raises: KeyError: If target_id is not in target_gdf columns. """ self.id = source_collection.id self.target_id = target_id self.target_gdf = target_gdf self._gdf = _read_shp_file(self.target_gdf) if self.target_id not in self._gdf.columns: logger.error(f"target_id {self.target_id} not in target_gdf columns: {self._gdf.columns}") raise KeyError(f"target_id {self.target_id} not in target_gdf columns: {self._gdf.columns}") self.source_var = source_var if isinstance(source_var, list) else [source_var] self.target_crs = self._gdf.crs def get_target_crs(self) -> str | int | CRS: """Return the CRS for the target geometries.""" return self.target_crs def get_feature_id(self) -> str: """Return target_id.""" return self.target_id def get_vars(self) -> list[str]: """Return list of vars in data.""" return self.source_var @staticmethod def _parse_cube_coords( source_collection, source_var: str # noqa: ANN001 ) -> dict[str, str | None]: """Extract coordinate names from STAC cube:variables and cube:dimensions. First looks up *source_var* in ``cube:variables`` to get the dimension names specific to that variable, then classifies each dimension using ``cube:dimensions`` axis/type metadata. This avoids ambiguity when a collection contains both primary and staggered grid dimensions (e.g. ``x`` vs ``x_stag`` in CONUS404). When ``cube:variables`` metadata is absent, falls back to scanning all ``cube:dimensions`` entries with a prioritized strategy: 1. **Exact canonical names** — ``"x"``, ``"y"``, ``"time"`` are preferred if they appear in ``cube:dimensions``. 2. **Axis/type matches** — the first dimension with ``axis="x"``, ``axis="y"``, or ``type="temporal"`` is used for any slot not already filled by step 1. This prevents staggered dimensions (e.g. ``x_stag``) from being selected when canonical names are available. .. note:: :meth:`_resolve_cube_coords` calls this method for every requested variable and warns if any variable resolves to different coordinate names. Collections with variables on different grids (e.g. staggered vs. unstaggered) should be queried one variable at a time. Returns: dict with keys ``"x"``, ``"y"``, ``"t"`` mapped to the dimension names found, or ``None`` when a dimension is absent. """ cube_dims = source_collection.extra_fields.get("cube:dimensions", {}) cube_vars = source_collection.extra_fields.get("cube:variables", {}) result: dict[str, str | None] = {"x": None, "y": None, "t": None} # Narrow the search to only the dimensions used by this variable var_info = cube_vars.get(source_var, {}) var_dims = var_info.get("dimensions") if var_dims: # Variable-specific dimensions — no ambiguity, first match wins for dim_name in var_dims: dim_info = cube_dims.get(dim_name, {}) axis = dim_info.get("axis") if axis == "x" and result["x"] is None: result["x"] = dim_name elif axis == "y" and result["y"] is None: result["y"] = dim_name if dim_info.get("type") == "temporal" and result["t"] is None: result["t"] = dim_name else: # No cube:variables — use prioritised fallback across all dims. # Pass 1: prefer canonical names (x, y, time) _canonical = {"x": "x", "y": "y", "time": "t"} for name, slot in _canonical.items(): if name in cube_dims: result[slot] = name # Pass 2: fill remaining slots from axis/type metadata for dim_name in cube_dims: dim_info = cube_dims.get(dim_name, {}) axis = dim_info.get("axis") if axis == "x" and result["x"] is None: result["x"] = dim_name elif axis == "y" and result["y"] is None: result["y"] = dim_name if dim_info.get("type") == "temporal" and result["t"] is None: result["t"] = dim_name return result @classmethod def _resolve_cube_coords( cls, source_collection, # noqa: ANN001 source_vars: list[str], ) -> dict[str, str | None]: """Resolve coordinate names, validating consistency across variables. Calls :meth:`_parse_cube_coords` for each variable in *source_vars*. If any variable resolves to different coordinate names a warning is logged and the coordinates of the **first** variable are used. Returns: dict with keys ``"x"``, ``"y"``, ``"t"``. """ first = cls._parse_cube_coords(source_collection, source_vars[0]) for var in source_vars[1:]: other = cls._parse_cube_coords(source_collection, var) if other != first: logger.warning( "Variable %r resolves to different coordinates %s " "than %r %s — using coordinates from %r. " "Consider querying variables on different grids separately.", var, other, source_vars[0], first, source_vars[0], ) break return first def _get_ds_var_attrs(self, da: xr.DataArray, attr: str) -> Any: # noqa: ANN401 """Return source_var attributes. Args: da (xr.DataArray): Target DataArray to pull attributes from. attr (str): Name of the attribute to return. Returns: The value of the attribute, or "None" if not found. """ try: return da.attrs.get(attr) except KeyError: return "None"
[docs] class NHGFStacZarrData(_NHGFStacBase): """Interface for Zarr-backed NHGF STAC catalog datasets. This class provides access to Zarr datasets through the NHGF STAC catalog, with automatic spatiotemporal filtering and metadata handling. Attributes: source_collection: STAC collection identifier for the dataset. source_var: Variable name(s) to access from the collection. target_gdf: GeoDataFrame containing target geometries. target_id: Column name for unique identifiers in target geometries. source_time_period: Time period for temporal filtering. Examples: .. code-block:: python from gdptools import NHGFStacZarrData from gdptools.helpers import get_stac_collection collection = get_stac_collection("conus404_daily") data = NHGFStacZarrData( source_collection=collection, source_var=["PWAT"], target_gdf="watersheds.shp", target_id="huc12", source_time_period=["2020-01-01", "2020-01-31"], ) """
[docs] def __init__( self, *, source_collection, # noqa: ANN001 source_var: Union[str, list[str]], target_gdf: Union[str, Path, gpd.GeoDataFrame], target_id: str, source_time_period: list[Union[str, pd.Timestamp, datetime.datetime] | None], ) -> None: """Initialize NHGFStacZarrData for Zarr-backed STAC collections. Args: source_collection: STAC collection object with a ``zarr-s3-osn`` asset. source_var (Union[str, list[str]]): Variable name(s) to aggregate. target_gdf (Union[str, Path, gpd.GeoDataFrame]): Target geometries. target_id (str): Column name in target_gdf containing unique identifiers. source_time_period (list[str]): Two-element list ``[start, end]`` defining the temporal subset (``'YYYY-MM-DD'`` or with time component). Selection is label-based and inclusive on both ends (xarray ``.sel(slice(...))`` semantics). For coarse time resolutions (monthly, 8-day composites) the period must include at least one label; a window that falls between labels produces an empty slice and raises ``ValueError``. For sub-daily sources specified with date-only strings, xarray treats the end date as ``T00:00:00`` — the final day's afternoon is excluded. Raises: KeyError: If target_id is not in target_gdf columns. """ logger.info("Initializing NHGFStacZarrData") logger.info(" - loading data") super().__init__( source_collection=source_collection, source_var=source_var, target_gdf=target_gdf, target_id=target_id, ) # Normalize the period before opening the remote Zarr so malformed or # reversed inputs fail fast without any network access. self.source_time_period = _process_period(source_time_period) self.asset = source_collection.assets["zarr-s3-osn"] self.source_ds = xr.open_zarr( self.asset.href, storage_options=self.asset.extra_fields["xarray:storage_options"], ) self.source_crs = self.source_ds.crs.attrs["crs_wkt"] if type(CRS.from_string(self.source_crs)) is not CRS: logger.error("Projection of the gridded dataset could not be identified") cube_coords = self._resolve_cube_coords(source_collection, self.source_var) self.source_x_coord = cube_coords["x"] or "x" self.source_y_coord = cube_coords["y"] or "y" self.source_t_coord = cube_coords["t"] or "time" logger.debug( f" - coordinate names: x={self.source_x_coord}, " f"y={self.source_y_coord}, t={self.source_t_coord}" ) self._gdf_bounds = _get_shp_bounds_w_buffer( gdf=self._gdf, ds=self.source_ds, crs=self.source_crs, lon=self.source_x_coord, lat=self.source_y_coord, ) logger.info(" - checking latitude bounds") is_intersect, is_degrees, is_0_360 = _check_for_intersection_nc( ds=self.source_ds, x_name=self.source_x_coord, y_name=self.source_y_coord, proj=self.source_crs, gdf=self._gdf, ) self.source_ds = _rotate_longitude_if_needed( self.source_ds, self.source_x_coord, should_rotate=bool((not is_intersect) & is_degrees & (is_0_360)), ) # Verify source_time_period actually selects at least one time label. # Eager check on the (small) time coordinate of the opened Zarr. _assert_period_intersects_time_axis( self.source_ds, self.source_t_coord, self.source_time_period ) # calculate toptobottom (order of latitude coords) self._ttb = _get_top_to_bottom(self.source_ds, self.source_y_coord) logger.info(" - getting gridded data subset") self._weight_subset_dict = build_subset( bounds=self._gdf_bounds, xname=self.source_x_coord, yname=self.source_y_coord, tname=self.source_t_coord, toptobottom=self._ttb, date_min=self.source_time_period[0], ) self._agg_subset_dict = build_subset( bounds=self._gdf_bounds, xname=self.source_x_coord, yname=self.source_y_coord, tname=self.source_t_coord, toptobottom=self._ttb, date_min=self.source_time_period[0], date_max=self.source_time_period[1], )
[docs] @classmethod def __repr__(cls) -> str: """Print class name.""" return f"Class is {cls.__name__}"
[docs] def get_source_subset(self, key: str) -> xr.DataArray: """Get a subset of the STAC data source for a specific variable. Args: key (str): Name of the variable to subset. Returns: xr.DataArray: Subsetted dataarray. """ return self.source_ds[key].sel(**self._agg_subset_dict)
[docs] def get_class_type(self) -> str: """Return the type of the data class. Returns ``"NHGFStacData"`` for backward compatibility. """ return "NHGFStacData"
[docs] def prep_interp_data(self, key: str, poly_id: Union[str, int]) -> AggData: """Prep AggData from NHGFStacZarrData. Args: key (str): Name of the xarray gridded data variable. poly_id (Union[str, int]): ID number of the geodataframe geometry to clip the gridded data to. Returns: AggData: An instance of the AggData class, ready for interpolation. """ # Open grid and clip to geodataframe and time window data_ss = self.get_source_subset(key) cat_cr = self._create_climrcats(key=key, da=data_ss) # Select a feature and make sure it remains a geodataframe target_gdf = self._gdf[self._gdf[self.target_id] == poly_id] return AggData( variable=key, cat_cr=cat_cr, da=data_ss, target_gdf=target_gdf, target_id=self.target_id, source_time_period=self.source_time_period, )
[docs] def prep_agg_data(self, key: str) -> AggData: """Prep AggData from NHGFStacZarrData.""" logger.info("Agg Data preparation - beginning") data_ss: xr.DataArray = self.get_source_subset(key) cat_cr = self._create_climrcats(key=key, da=data_ss) target_gdf = self._gdf # If the time dimension has only one step: try: data_ss.coords.get(self.source_t_coord).all() source_time_period = self.source_time_period except Exception: source_time_period = ["None", "None"] logger.info(" - returning AggData") return AggData( variable=key, cat_cr=cat_cr, da=data_ss, target_gdf=target_gdf, target_id=self.target_id, source_time_period=source_time_period, )
def _create_climrcats(self: "NHGFStacZarrData", key: str, da: xr.DataArray) -> CatClimRItem: """Create a CatClimRItem instance from STAC metadata. Args: key (str): The variable name. da (xr.DataArray): The DataArray for the variable, used to extract metadata. Returns: CatClimRItem: A dataclass instance containing the metadata. """ return CatClimRItem( URL=self.asset.href, varname=key, long_name=str(self._get_ds_var_attrs(da, "description")), T_name=self.source_t_coord, units=str(self._get_ds_var_attrs(da, "units")), X_name=self.source_x_coord, Y_name=self.source_y_coord, proj=str(self.source_crs), resX=max(np.diff(da[self.source_x_coord].values)), resY=max(np.diff(da[self.source_y_coord].values)), toptobottom=str(_get_top_to_bottom(da, self.source_y_coord)), )
[docs] def prep_wght_data(self) -> WeightData: """Prepare and return WeightData for weight generation.""" logger.info("Weight Data preparation - beginning") data_ss_wght = self.source_ds.sel(**self._weight_subset_dict) # type: ignore logger.info(" - calculating grid-cell polygons") start = time.perf_counter() grid_poly = _get_cells_poly_fast( xr_a=data_ss_wght, x=self.source_x_coord, y=self.source_y_coord, crs_in=self.source_crs, ) end = time.perf_counter() logger.info(f"Generating grid-cell polygons finished in {round(end - start, 2)} second(s)") return WeightData(target_gdf=self._gdf, target_id=self.target_id, grid_cells=grid_poly)
[docs] class NHGFStacTiffData(_NHGFStacBase): """Interface for GeoTIFF-backed NHGF STAC catalog datasets (e.g., NLCD). This class provides access to GeoTIFF datasets in the NHGF STAC catalog, where assets are stored as individual GeoTIFF files per time step (one STAC item per year). It handles remote S3 access, band selection, and spatial subsetting. Currently supports single-item (single year) access. Multi-item stacking along a time axis is planned for a future release. Attributes: source_collection: STAC collection for the dataset. source_var: Variable name(s) to access. target_gdf: GeoDataFrame containing target geometries. target_id: Column name for unique identifiers in target geometries. band: Selected band number (1-indexed). Examples: .. code-block:: python from gdptools import NHGFStacTiffData from gdptools.helpers import get_stac_collection collection = get_stac_collection("nlcd-LndCov") data = NHGFStacTiffData( source_collection=collection, source_var=["LndCov"], target_gdf="watersheds.shp", target_id="huc12", source_time_period=["2021-01-01", "2021-12-31"], ) """
[docs] def __init__( self, *, source_collection, # noqa: ANN001 source_var: Union[str, list[str]], target_gdf: Union[str, Path, gpd.GeoDataFrame], target_id: str, source_time_period: list[Union[str, pd.Timestamp, datetime.datetime] | None] | None = None, band: int = 1, bname: str = "band", ) -> None: """Initialize NHGFStacTiffData for GeoTIFF-backed STAC collections. Args: source_collection: STAC collection whose items have GeoTIFF assets. source_var (Union[str, list[str]]): Variable name(s) for the raster data. target_gdf (Union[str, Path, gpd.GeoDataFrame]): Target geometries. target_id (str): Column name in target_gdf containing unique identifiers. source_time_period: Optional two-element list ``[start, end]`` to select which STAC item to load by datetime. If ``None``, the first item is used. Unlike Zarr-backed STAC, TIFF-backed STAC does not slice along a time axis — the period only narrows the candidate STAC items. Reversed ranges (``start > end``) and malformed date strings are still rejected at construction by ``_process_period``. band (int): Band number to select (1-indexed). Defaults to 1. bname (str): Name of the band dimension. Defaults to ``"band"``. Raises: KeyError: If target_id is not in target_gdf columns. ValueError: If no STAC items found or no matching item for the time period. """ import rasterio import rioxarray # noqa: F401 from rasterio.windows import from_bounds as window_from_bounds logger.info("Initializing NHGFStacTiffData") logger.info(" - loading data") # Normalize the period before super().__init__ + STAC item discovery so # malformed or reversed inputs fail fast. Preserve the ['None', 'None'] # sentinel for the "no period specified" case. if source_time_period is None: normalized_period = ["None", "None"] else: normalized_period = _process_period(source_time_period) super().__init__( source_collection=source_collection, source_var=source_var, target_gdf=target_gdf, target_id=target_id, ) self.source_time_period = normalized_period self.band = band self.bname = bname self.source_t_coord = None cube_coords = self._resolve_cube_coords(source_collection, self.source_var) self.source_x_coord = cube_coords["x"] or "x" self.source_y_coord = cube_coords["y"] or "y" logger.debug( f" - coordinate names: x={self.source_x_coord}, y={self.source_y_coord}" ) # Discover items and select one based on time period items = list(source_collection.get_items()) if not items: raise ValueError( f"No items found in STAC collection '{source_collection.id}'. " "Ensure the collection contains GeoTIFF items." ) self._stac_item = self._select_item(items, self.source_time_period) # Find the GeoTIFF asset self._asset_key, self._asset = self._find_tiff_asset(self._stac_item) self._asset_href = self._asset.href # Build a URL that GDAL can open via its /vsicurl/ driver. # S3 paths are converted to HTTPS using the endpoint from # storage_options; hrefs that are already HTTP(S) are used as-is. # This enables efficient range-request reads for Cloud-Optimized # GeoTIFFs — only the tiles covering the target area are fetched, # avoiding loading the full raster into memory. storage_options = self._asset.extra_fields.get("xarray:storage_options", {}) self._https_url = _stac_href_to_url(self._asset_href, storage_options) logger.info(f" - opening GeoTIFF (clipped to target extent): {self._https_url}") with rasterio.open(self._https_url) as src: source_crs_rio = src.crs # Compute buffered target bounds in source CRS for windowed # reading. Only the pixels covering the target polygons (plus # a 5 % buffer) are fetched from the remote COG. target_bounds = self._gdf.to_crs(source_crs_rio).total_bounds dx = abs(target_bounds[2] - target_bounds[0]) * 0.05 dy = abs(target_bounds[3] - target_bounds[1]) * 0.05 clip_bounds = ( target_bounds[0] - dx, target_bounds[1] - dy, target_bounds[2] + dx, target_bounds[3] + dy, ) # Compute pixel window and clamp to raster extent window = window_from_bounds(*clip_bounds, transform=src.transform) window = window.round_offsets().round_lengths() full_window = rasterio.windows.Window(0, 0, src.width, src.height) try: window = window.intersection(full_window) except rasterio.errors.WindowError as exc: raise ValueError( f"Target polygons do not overlap the raster extent. " f"Target bounds (source CRS): {tuple(target_bounds)}, " f"raster bounds: {src.bounds}" ) from exc if window.width == 0 or window.height == 0: raise ValueError( f"Target polygons do not overlap the raster extent. " f"Target bounds (source CRS): {tuple(target_bounds)}, " f"raster bounds: {src.bounds}" ) # Read only the windowed subset data = src.read(window=window) win_transform = src.window_transform(window) # Build coordinate arrays (pixel centers) nrows, ncols = data.shape[1], data.shape[2] x_coords = win_transform.c + (np.arange(ncols) + 0.5) * win_transform.a y_coords = win_transform.f + (np.arange(nrows) + 0.5) * win_transform.e band_coords = np.arange(1, data.shape[0] + 1) self.source_ds = xr.DataArray( data, dims=[self.bname, self.source_y_coord, self.source_x_coord], coords={ self.bname: band_coords, self.source_y_coord: y_coords, self.source_x_coord: x_coords, }, ) self.source_ds = self.source_ds.rio.write_crs(source_crs_rio) self.source_ds = self.source_ds.rio.write_transform(win_transform) # Extract CRS self.source_crs = self.source_ds.rio.crs if self.source_crs is None: raise ValueError("Could not determine CRS from the GeoTIFF.") self.source_crs = self.source_crs.to_epsg() or self.source_crs.to_wkt() # Spatial validation self._toptobottom = _get_top_to_bottom(data=self.source_ds, y_coord=self.source_y_coord) self.varname = source_var if isinstance(source_var, str) else source_var[0]
@staticmethod def _select_item(items: list, source_time_period: list | None) -> "pystac.Item": """Select a STAC item based on the time period. If ``source_time_period`` is provided, selects the item whose datetime falls within the range. Otherwise returns the first item. Expects ``source_time_period`` to already be normalized by ``_process_period`` — elements are strings or ``None``. STAC item datetimes are tz-aware (UTC) per the STAC spec, so user-provided naive timestamps are coerced to UTC before comparison. """ if source_time_period is None or source_time_period == ["None", "None"]: return items[0] start_raw, end_raw = source_time_period[0], source_time_period[1] start = pd.Timestamp(start_raw) if start_raw is not None else None end = pd.Timestamp(end_raw) if end_raw is not None else None if start is not None and start.tzinfo is None: start = start.tz_localize("UTC") if end is not None and end.tzinfo is None: end = end.tz_localize("UTC") for item in items: dt_str = item.properties.get("datetime") or item.properties.get("start_datetime") if not dt_str: continue item_dt = pd.Timestamp(dt_str) if item_dt.tzinfo is None: item_dt = item_dt.tz_localize("UTC") if (start is None or start <= item_dt) and (end is None or item_dt <= end): return item available_dates = [ item.properties.get("datetime") or item.properties.get("start_datetime") for item in items if item.properties.get("datetime") or item.properties.get("start_datetime") ] raise ValueError( f"No STAC item found matching time period {source_time_period}. " f"Available dates: {available_dates[:10]}{'...' if len(available_dates) > 10 else ''}" ) @staticmethod def _find_tiff_asset(item: "pystac.Item") -> tuple[str, "pystac.Asset"]: """Find the GeoTIFF asset in a STAC item. Looks for asset key ``tif-s3-osn`` first, then falls back to any asset whose media type contains ``image/tiff``. """ # Preferred key if "tif-s3-osn" in item.assets: return "tif-s3-osn", item.assets["tif-s3-osn"] # Fallback: search by media type for key, asset in item.assets.items(): media_type = getattr(asset, "media_type", "") or "" extra = getattr(asset, "extra_fields", {}) or {} asset_type_str = extra.get("type", "") if "image/tiff" in media_type.lower() or "image/tiff" in asset_type_str.lower(): return key, asset raise ValueError( f"No GeoTIFF asset found in STAC item '{item.id}'. " f"Available assets: {list(item.assets.keys())}" )
[docs] @classmethod def __repr__(cls) -> str: """Print class name.""" return f"Class is {cls.__name__}"
[docs] def get_source_subset(self, key: str) -> xr.DataArray: """Get a spatially subset of the source raster data. Args: key (str): Variable name (used for interface consistency). Returns: xr.DataArray: Spatially subsetted DataArray. """ bb_feature = _get_shp_bounds_w_buffer( gdf=self._gdf, ds=self.source_ds, crs=self.source_crs, lon=self.source_x_coord, lat=self.source_y_coord, ) subset_dict = build_subset_tiff( bounds=bb_feature, xname=self.source_x_coord, yname=self.source_y_coord, toptobottom=self._toptobottom, bname=self.bname, band=self.band, ) return self.source_ds.sel(**subset_dict)
[docs] def get_class_type(self) -> str: """Return the type of the data class.""" return "NHGFStacTiffData"
[docs] def prep_interp_data(self, key: str, poly_id: Union[str, int]) -> AggData: """Prep AggData from NHGFStacTiffData for interpolation. Args: key (str): Name of the variable. poly_id (Union[str, int]): ID of the target geometry. Returns: AggData: An instance ready for interpolation. """ target_gdf = self._gdf[self._gdf[self.target_id] == poly_id] bb_feature = _get_shp_bounds_w_buffer( gdf=target_gdf, ds=self.source_ds, crs=self.source_crs, lon=self.source_x_coord, lat=self.source_y_coord, ) subset_dict = build_subset_tiff( bounds=bb_feature, xname=self.source_x_coord, yname=self.source_y_coord, toptobottom=self._toptobottom, bname=self.bname, band=self.band, ) data_ss: xr.DataArray = self.source_ds.sel(**subset_dict) cat_cr = self._create_climrcats(key=key, da=data_ss) return AggData( variable=key, cat_cr=cat_cr, da=data_ss, target_gdf=target_gdf, target_id=self.target_id, source_time_period=["None", "None"], )
[docs] def prep_agg_data(self, key: str) -> AggData: """Prep AggData from NHGFStacTiffData for aggregation. Args: key (str): Name of the variable. Returns: AggData: An instance ready for aggregation. """ bb_feature = _get_shp_bounds_w_buffer( gdf=self._gdf, ds=self.source_ds, crs=self.source_crs, lon=self.source_x_coord, lat=self.source_y_coord, ) subset_dict = build_subset_tiff( bounds=bb_feature, xname=self.source_x_coord, yname=self.source_y_coord, toptobottom=self._toptobottom, bname=self.bname, band=self.band, ) data_ss: xr.DataArray = self.source_ds.sel(**subset_dict) if data_ss.size == 0: raise ValueError( "Sub-setting the raster resulted in no values. " f"Check the specified source_crs value: {self.source_crs}" ) cat_cr = self._create_climrcats(key=key, da=data_ss) return AggData( variable=key, cat_cr=cat_cr, da=data_ss, target_gdf=self._gdf.copy(), target_id=self.target_id, source_time_period=["None", "None"], )
[docs] def prep_wght_data(self) -> WeightData: """Prepare weight data. Not implemented for TIFF-based STAC data. Zonal statistics via ``ZonalGen`` or ``WeightedZonalGen`` is the primary analysis path. """ pass
def _create_climrcats(self: "NHGFStacTiffData", key: str, da: xr.DataArray) -> CatClimRItem: """Create a CatClimRItem instance from STAC/raster metadata. Args: key (str): The variable name. da (xr.DataArray): The DataArray, used to extract metadata. Returns: CatClimRItem: Metadata container. """ return CatClimRItem( URL=self._asset_href, varname=key, long_name=str(self._get_ds_var_attrs(da, "description")), units=str(self._get_ds_var_attrs(da, "units")), X_name=self.source_x_coord, Y_name=self.source_y_coord, proj=str(self.source_crs), resX=max(np.diff(da[self.source_x_coord].values)), resY=max(np.diff(da[self.source_y_coord].values)), toptobottom=str(self._toptobottom), )
def _stac_href_to_url(href: str, storage_options: dict) -> str: """Convert a STAC asset href to a URL that GDAL can open. * ``s3://bucket/key`` hrefs are converted to HTTPS using the endpoint in *storage_options* (or the USGS OSN default). * ``http://`` / ``https://`` hrefs are returned unchanged. * Bare paths (no scheme) are prefixed with the endpoint URL. """ if href.startswith(("http://", "https://")): return href endpoint_url = ( storage_options.get("client_kwargs", {}).get("endpoint_url", "") or "https://usgs.osn.mghpcc.org/" ) s3_path = href.removeprefix("s3://") return f"{endpoint_url.rstrip('/')}/{s3_path}" def _detect_stac_asset_type(collection) -> str: # noqa: ANN001 """Detect whether a STAC collection is Zarr-backed or GeoTIFF-backed. Args: collection: A pystac Collection object. Returns: ``"zarr"`` if the collection has Zarr assets, ``"tiff"`` if items have GeoTIFF assets. Raises: ValueError: If the asset type cannot be determined. """ # Check for collection-level Zarr assets if hasattr(collection, "assets") and collection.assets: for key, asset in collection.assets.items(): href = getattr(asset, "href", "") media_type = getattr(asset, "media_type", "") or "" if ( "zarr" in key.lower() or href.endswith(".zarr") or href.endswith(".zarr/") or "zarr" in media_type.lower() ): return "zarr" # Check for item-level GeoTIFF assets try: items = list(collection.get_items()) if items: for key, asset in items[0].assets.items(): media_type = getattr(asset, "media_type", "") or "" extra = getattr(asset, "extra_fields", {}) or {} asset_type_str = extra.get("type", "") if ( "tif" in key.lower() or "image/tiff" in media_type.lower() or "image/tiff" in asset_type_str.lower() ): return "tiff" except Exception: logger.debug("Failed to retrieve items from collection '%s'", collection.id, exc_info=True) raise ValueError( f"Cannot determine asset type for collection '{collection.id}'. " "No Zarr assets or GeoTIFF items found. " "Use NHGFStacZarrData or NHGFStacTiffData directly, or pass asset_type='zarr'|'tiff'." )
[docs] class NHGFStacData: """Factory for NHGF STAC catalog datasets. Auto-detects whether the STAC collection is Zarr-backed or GeoTIFF-backed and returns the appropriate concrete class (``NHGFStacZarrData`` or ``NHGFStacTiffData``). This preserves backward compatibility: existing code using ``NHGFStacData(source_collection=zarr_collection, ...)`` continues to work and receives an ``NHGFStacZarrData`` instance. Examples: Zarr collection (auto-detected): .. code-block:: python from gdptools import NHGFStacData from gdptools.helpers import get_stac_collection collection = get_stac_collection("conus404_daily") data = NHGFStacData( source_collection=collection, source_var=["PWAT"], target_gdf="watersheds.shp", target_id="huc12", source_time_period=["2020-01-01", "2020-01-31"], ) # data is an NHGFStacZarrData instance GeoTIFF collection (auto-detected): .. code-block:: python collection = get_stac_collection("nlcd-LndCov") data = NHGFStacData( source_collection=collection, source_var=["LndCov"], target_gdf="watersheds.shp", target_id="huc12", source_time_period=["2021-01-01", "2021-12-31"], ) # data is an NHGFStacTiffData instance """ _deprecation_map = { # noqa: RUF012 "collection": "source_collection", "var": "source_var", "f_feature": "target_gdf", "id_feature": "target_id", "period": "source_time_period", }
[docs] @deprecate_kwargs(_deprecation_map, removed_in="1.0.0") def __new__( cls, *, source_collection, # noqa: ANN001 source_var: Union[str, list[str]], target_gdf: Union[str, Path, gpd.GeoDataFrame], target_id: str, source_time_period: list[Union[str, pd.Timestamp, datetime.datetime] | None] | None = None, asset_type: str | None = None, band: int = 1, bname: str = "band", **kwargs, # noqa: ANN003 ) -> "NHGFStacZarrData | NHGFStacTiffData": """Create an NHGF STAC data source with auto-detected format. Args: source_collection: STAC collection object. source_var: Variable name(s) to process. target_gdf: Target geometries. target_id: Column name for unique identifiers. source_time_period: Time period ``[start, end]``. Required for Zarr; optional for GeoTIFF (selects which item to load). For Zarr-backed STAC, selection is label-based and inclusive on both ends (xarray ``.sel(slice(...))`` semantics). For coarse time resolutions (monthly, 8-day composites) the period must include at least one label; a window that falls between labels produces an empty slice and raises ``ValueError``. For sub-daily sources specified with date-only strings, xarray treats the end date as ``T00:00:00`` — the final day's afternoon is excluded. For TIFF-backed STAC the period only narrows the candidate STAC items. asset_type: Optional override — ``"zarr"`` or ``"tiff"``. If ``None``, auto-detected from the collection. band: Band number for GeoTIFF (1-indexed). Ignored for Zarr. bname: Band dimension name for GeoTIFF. Ignored for Zarr. **kwargs: Additional keyword arguments for backward compatibility with deprecated parameter names. Returns: NHGFStacZarrData or NHGFStacTiffData instance. """ if asset_type is None: asset_type = _detect_stac_asset_type(source_collection) if asset_type == "zarr": if source_time_period is None: raise ValueError("source_time_period is required for Zarr-backed STAC collections.") return NHGFStacZarrData( source_collection=source_collection, source_var=source_var, target_gdf=target_gdf, target_id=target_id, source_time_period=source_time_period, ) elif asset_type == "tiff": return NHGFStacTiffData( source_collection=source_collection, source_var=source_var, target_gdf=target_gdf, target_id=target_id, source_time_period=source_time_period, band=band, bname=bname, ) else: raise ValueError( f"Unknown asset_type '{asset_type}'. Must be 'zarr' or 'tiff'." )
[docs] class UserTiffData(UserData): """Handler for GeoTIFF and other raster data sources. This class is optimized for working with raster data sources such as GeoTIFF files, providing specialized functionality for zonal statistics and spatial analysis operations. It handles single and multi-band rasters with automatic band selection and coordinate system validation. The class is particularly useful for processing elevation data, land cover classifications, and other raster datasets that require zonal statistics calculations over vector geometries. Attributes: source_ds: The source raster data as xarray DataArray or Dataset. target_gdf: GeoDataFrame containing target geometries for zonal operations. target_id: Column name for unique identifiers in target geometries. band: Selected band number for multi-band rasters. source_var: Variable name assigned to the raster data. Examples: Basic elevation processing: .. code-block:: python from gdptools.data.user_data import UserTiffData # Process elevation data data = UserTiffData( source_ds="elevation.tif", source_crs=4326, source_x_coord="x", source_y_coord="y", target_gdf="watersheds.shp", target_id="huc12" ) # Prepare for zonal statistics weight_data = data.prep_wght_data() Multi-band raster processing: .. code-block:: python # Select specific band from multi-band raster data = UserTiffData( source_ds="landcover.tif", source_crs=3857, source_x_coord="x", source_y_coord="y", target_gdf=polygons_gdf, target_id="poly_id", band=3, # Select band 3 source_var="landcover_class" ) In-memory raster data: .. code-block:: python import rioxarray as rxr raster = rxr.open_rasterio("slope.tif") data = UserTiffData( source_ds=raster, source_crs=raster.rio.crs, source_x_coord="x", source_y_coord="y", target_gdf="regions.gpkg", target_id="region_id" ) Notes: This class automatically handles band selection and coordinate system validation for raster data. It's optimized for zonal statistics workflows and integrates seamlessly with the ZonalGen classes. """ _deprecation_map = { # noqa: RUF012 "ds": "source_ds", "proj_ds": "source_crs", "x_coord": "source_x_coord", "y_coord": "source_y_coord", "t_coord": "source_t_coord", "var": "source_var", "f_feature": "target_gdf", "proj_feature": "target_crs", "id_feature": "target_id", "period": "source_time_period", }
[docs] @deprecate_kwargs(_deprecation_map, removed_in="1.0.0") def __init__( self, source_ds: Union[str, xr.DataArray, xr.Dataset], source_crs: str | int | CRS, source_x_coord: str, source_y_coord: str, target_gdf: Union[str, Path, gpd.GeoDataFrame], target_id: str, bname: str = "band", band: int = 1, source_var: str = "tiff", ) -> None: """Initialize UserTiffData for raster data processing. Args: source_ds: Raster data source as a file path, xarray DataArray, or Dataset. source_crs: Coordinate reference system of the raster data. source_x_coord: Name of the x-coordinate dimension in the raster. source_y_coord: Name of the y-coordinate dimension in the raster. target_gdf: Target geometries as a GeoDataFrame or file path. target_id: Column name in `target_gdf` for unique identifiers. bname: Name of the band dimension in multi-band rasters. Defaults to "band". band: Band number to select from a multi-band raster (1-indexed). Defaults to 1. source_var: A name to assign to the raster data variable. Defaults to "tiff". Raises: FileNotFoundError: If `source_ds` file path does not exist. KeyError: If `target_id` is not found in `target_gdf` columns. ValueError: If `source_crs` is invalid or the band number is out of range. """ self.varname = source_var # Need in zonal_engines to convert xarray dataarray to dataset self.source_x_coord = source_x_coord self.source_y_coord = source_y_coord self.bname = bname self.band = band self.source_ds = _get_rxr_dataset(source_ds) if not _is_valid_crs(source_crs): raise ValueError( f"Invalid CRS specification: {source_crs!r}. " "Please provide a valid CRS (e.g., EPSG code, proj string, WKT, or pyproj.CRS object)." ) self.source_crs = source_crs self.target_gdf = _read_shp_file(shp_file=target_gdf) self.target_id = target_id self.target_gdf = self.target_gdf.sort_values(self.target_id).dissolve(by=self.target_id, as_index=False) self.target_gdf.reset_index() self.target_crs = self.target_gdf.crs.to_epsg() self._check_xname() self._check_yname() self._check_band() self._check_crs() self._toptobottom = _get_top_to_bottom(data=self.source_ds, y_coord=self.source_y_coord)
[docs] def get_target_crs(self) -> CRS: """Get the coordinate reference system of the target geometries. Returns: CRS: The coordinate reference system of the target vector data. """ return self.target_crs
[docs] def get_source_subset(self, key: str) -> xr.DataArray: """Get a spatially subset of the source raster data. This method subsets the source raster based on the buffered bounding box of the target geometries. The `key` argument is not used for this class but is required for interface consistency. Args: key (str): A placeholder argument for interface consistency. Returns: xr.DataArray: A spatially subsetted xarray DataArray. """ bb_feature = _get_shp_bounds_w_buffer( gdf=self.target_gdf, ds=self.source_ds, crs=self.source_crs, lon=self.source_x_coord, lat=self.source_y_coord, ) subset_dict = build_subset_tiff( bounds=bb_feature, xname=self.source_x_coord, yname=self.source_y_coord, toptobottom=self._toptobottom, bname=self.bname, band=self.band, ) return self.source_ds.sel(**subset_dict) # type: ignore
[docs] def get_vars(self) -> list[str]: """Get the list of available variables. For `UserTiffData`, this is typically a single variable name assigned during initialization. Returns: list[str]: A list containing the single variable name. """ return [self.source_ds] if isinstance(self.source_ds, str) else [self.varname]
[docs] def get_feature_id(self) -> str: """Get the identifier column name for target geometries.""" return self.target_id
[docs] def prep_wght_data(self) -> WeightData: """Prepare data for weight generation. Notes: This method is not yet implemented for `UserTiffData`. Zonal statistics for rasters are handled by `prep_agg_data`. """ pass
[docs] def get_class_type(self) -> str: """Get the type identifier for this data source class.""" return "UserTiffData"
[docs] def prep_interp_data(self, key: str, poly_id: int) -> AggData: """Prepare data for interpolation operations. This method subsets the source raster data to the bounding box of a specific target geometry and prepares an `AggData` object for interpolation. Args: key (str): The variable name to prepare for interpolation. poly_id (int): The identifier of the target geometry to use for subsetting. Returns: AggData: An instance ready for interpolation. """ # Select a feature and make sure it remains a geodataframe target_gdf = self.target_gdf[self.target_gdf[self.target_id] == poly_id] bb_feature = _get_shp_bounds_w_buffer( gdf=target_gdf, ds=self.source_ds, crs=self.source_crs, lon=self.source_x_coord, lat=self.source_y_coord, ) subset_dict = build_subset_tiff( bounds=bb_feature, xname=self.source_x_coord, yname=self.source_y_coord, toptobottom=self._toptobottom, bname=self.bname, band=self.band, ) data_ss: xr.DataArray = self.source_ds.sel(**subset_dict) # type: ignore cat_cr = self._create_climrcats(key=key, da=data_ss) return AggData( variable=key, cat_cr=cat_cr, da=data_ss, target_gdf=target_gdf, target_id=self.target_id, source_time_period=["None", "None"], )
[docs] def prep_agg_data(self, key: str) -> AggData: """Prepare data for aggregation or zonal statistics. This method subsets the source raster data to the buffered bounding box of the target geometries and prepares an `AggData` object. Args: key (str): The variable name to prepare for aggregation. Returns: AggData: An instance ready for aggregation. Raises: ValueError: If subsetting the raster results in an empty dataset, which can indicate a CRS mismatch or no spatial overlap. """ bb_feature = _get_shp_bounds_w_buffer( gdf=self.target_gdf, ds=self.source_ds, crs=self.source_crs, lon=self.source_x_coord, lat=self.source_y_coord, ) subset_dict = build_subset_tiff( bounds=bb_feature, xname=self.source_x_coord, yname=self.source_y_coord, toptobottom=self._toptobottom, bname=self.bname, band=self.band, ) data_ss: xr.DataArray = self.source_ds.sel(**subset_dict) # type: ignore if data_ss.size == 0: raise ValueError( "Sub-setting the raster resulted in no values", f"check the specified source_crs value: {self.source_crs}", f" and target_crs value: {self.target_crs}", ) cat_cr = self._create_climrcats(key=key, da=data_ss) return AggData( variable=key, cat_cr=cat_cr, da=data_ss, target_gdf=self.target_gdf.copy(), target_id=self.target_id, source_time_period=["None", "None"], )
def _check_xname(self: "UserTiffData") -> None: """Validate that the x-coordinate name exists in the dataset.""" if self.source_x_coord not in self.source_ds.coords: raise ValueError(f"xname:{self.source_x_coord} not in {self.source_ds.coords}") def _check_yname(self: "UserTiffData") -> None: """Validate that the y-coordinate name exists in the dataset.""" if self.source_y_coord not in self.source_ds.coords: raise ValueError(f"yname:{self.source_y_coord} not in {self.source_ds.coords}") def _check_band(self: "UserTiffData") -> None: """Validate that the band coordinate name exists in the dataset.""" if self.bname not in self.source_ds.coords: raise ValueError(f"band:{self.bname} not in {self.source_ds.coords} or {self.source_ds.data_vars}") def _check_crs(self: "UserTiffData") -> None: """Validate that the source and target CRS are valid.""" crs = CRS.from_user_input(self.source_crs) if not isinstance(crs, CRS): raise ValueError(f"ds_proj:{self.source_crs} not in valid") crs2 = CRS.from_user_input(self.target_crs) if not isinstance(crs2, CRS): raise ValueError(f"ds_proj:{self.target_crs} not in valid") def _create_climrcats(self: "UserTiffData", key: str, da: xr.DataArray) -> CatClimRItem: """Create a CatClimRItem instance from the raster metadata. This helper method constructs a `CatClimRItem` object, which is used internally to standardize metadata for processing. Args: key (str): The variable name. da (xr.DataArray): The DataArray for the variable, used to extract metadata. Returns: A `CatClimRItem` instance containing the metadata. """ return CatClimRItem( # id=self.id, URL="", varname=key, long_name=str(self._get_ds_var_attrs(da, "description")), units=str(self._get_ds_var_attrs(da, "units")), X_name=self.source_x_coord, Y_name=self.source_y_coord, proj=str(self.source_crs), resX=max(np.diff(da[self.source_x_coord].values)), resY=max(np.diff(da[self.source_y_coord].values)), toptobottom=str(_get_top_to_bottom(da, self.source_y_coord)), ) def _get_ds_var_attrs(self: "UserTiffData", da: xr.DataArray, attr: str) -> str: """Get a specific attribute from a DataArray. Safely retrieves an attribute from the DataArray's `attrs` dictionary. Args: da (xr.DataArray): The DataArray to inspect. attr (str): The attribute name to retrieve. Returns: The attribute value as a string, or "None" if not found. """ try: return str(da.attrs.get(attr)) except KeyError: return "None"