"""Helper functions for data subsetting, validation, and STAC catalog lookup.

This module provides utility functions that support the core functionality of
gdptools by providing:

- Spatial and temporal subsetting for xarray Datasets and DataArrays.
- Validation checks for gridded data dimensions.
- Lookup of collections in the NHGF STAC catalog.
"""
from __future__ import annotations
import logging
from collections.abc import Mapping
from typing import Any
import numpy as np
import numpy.typing as npt
import pystac
import xarray as xr
# Module logger; the STAC lookup helpers log best-effort failures at DEBUG level.
logger: logging.Logger = logging.getLogger(__name__)

# Root URL of the NHGF STAC catalog (note the trailing slash).
NHGF_STAC_CATALOG_URL: str = (
    "https://api.water.usgs.gov/gdp/pygeoapi/stac/stac-collection/"
)
# NOTE(review): ``GDPToolsError`` is not imported anywhere in this module as
# shown. Confirm it is brought into scope (e.g. from the package's exceptions
# module); otherwise this class statement raises NameError at import time.
class STACCatalogError(GDPToolsError):
    """Exception raised when STAC catalog operations fail.

    Raised by ``get_stac_collection`` when the NHGF STAC catalog cannot be
    loaded or when the requested collection cannot be found in it.
    """
# Maps human-readable time-unit names to single-letter frequency codes
# (presumably pandas offset/frequency aliases, judging by the name).
# NOTE(review): pandas >= 2.2 deprecates "H" (use "h") and, for offsets,
# "M"/"Y" (use "ME"/"YE"); confirm against how callers consume these codes.
pd_offset_conv: dict[str, str] = dict(
    years="Y",
    months="M",
    days="D",
    hours="H",
)
def build_subset(
    bounds: npt.NDArray[np.double],
    xname: str,
    yname: str,
    tname: str,
    toptobottom: bool,
    date_min: str | None = None,
    date_max: str | None = None,
) -> dict[str, object]:
    """Build an xarray ``.sel()`` selector for a bounding box and optional time range.

    The x dimension is always sliced ``minx -> maxx``. The y slice direction
    follows ``toptobottom``, because xarray label slices must run in the same
    order as the stored coordinate values. A time entry is added only when at
    least one date bound is supplied.

    Args:
        bounds: Spatial bounds array in format [minx, miny, maxx, maxy].
        xname: Name of the x-dimension in the dataset.
        yname: Name of the y-dimension in the dataset.
        tname: Name of the time dimension in the dataset.
        toptobottom: If truthy, y is sliced as ``slice(miny, maxy)``, suitable
            for an ascending y coordinate. If falsy, y is sliced as
            ``slice(maxy, miny)``, suitable for a descending y coordinate.
        date_min: Start date for temporal subset (ISO format string). If both
            dates are None, no temporal subsetting is applied.
        date_max: End date for temporal subset (ISO format string). If None and
            date_min is provided, only the exact date_min is selected. If given
            without date_min, the time range is open-ended at the start.

    Returns:
        Dictionary keyed by dimension name (x, then y, then time if present),
        whose values are slice objects or an exact time label.

    Examples:
        >>> bounds = np.array([-180, -90, 180, 90])
        >>> subset_dict = build_subset(
        ...     bounds, 'longitude', 'latitude', 'time', False,
        ...     '2020-01-01', '2020-12-31'
        ... )
        >>> data_subset = dataset.sel(subset_dict)
    """
    x_lo, x_hi = bounds[0], bounds[2]
    y_lo, y_hi = bounds[1], bounds[3]

    # Ascending coordinates need (low, high); descending ones need (high, low).
    y_range = slice(y_lo, y_hi) if toptobottom else slice(y_hi, y_lo)
    selector: dict[str, object] = {xname: slice(x_lo, x_hi), yname: y_range}

    if date_max is not None:
        selector[tname] = slice(date_min, date_max)
    elif date_min is not None:
        # A lone start date selects that exact timestamp rather than a range.
        selector[tname] = date_min
    return selector
def build_subset_tiff(
    bounds: npt.NDArray[np.double],
    xname: str,
    yname: str,
    toptobottom: bool,
    bname: str,
    band: int,
) -> Mapping[Any, Any]:
    """Build an xarray ``.sel()`` selector for a raster bounding box and a single band.

    The x dimension is sliced ``minx -> maxx``. The y slice direction follows
    ``toptobottom`` so that it matches the stored order of the y coordinate.
    The band dimension is selected by exact label.

    Args:
        bounds: Spatial bounds array in format [minx, miny, maxx, maxy].
        xname: Name of the x-dimension in the dataset.
        yname: Name of the y-dimension in the dataset.
        toptobottom: If truthy, y is sliced as ``slice(miny, maxy)``, suitable
            for an ascending y coordinate. If falsy, y is sliced as
            ``slice(maxy, miny)``, suitable for a descending y coordinate.
        bname: Name of the band dimension in the dataset.
        band: Specific band number to select.

    Returns:
        Dictionary keyed by dimension name (x, y, band) whose values are
        slice objects for x/y and the exact band label.

    Examples:
        >>> bounds = np.array([-180, -90, 180, 90])
        >>> subset_dict = build_subset_tiff(
        ...     bounds, 'x', 'y', True, 'band', 1
        ... )
        >>> raster_subset = raster_data.sel(subset_dict)
    """
    x_span = slice(bounds[0], bounds[2])
    if toptobottom:
        # Ascending y coordinate: slice from the low edge to the high edge.
        y_span = slice(bounds[1], bounds[3])
    else:
        # Descending y coordinate: label slices must run high -> low.
        y_span = slice(bounds[3], bounds[1])
    return {xname: x_span, yname: y_span, bname: band}
def build_subset_tiff_da(
    bounds: npt.NDArray[np.double],
    xname: str,
    yname: str,
    toptobottom: int | bool,
) -> Mapping[Any, Any]:
    """Build an xarray ``.sel()`` selector for a raster DataArray bounding box.

    Same spatial logic as the band-aware variant, but without a band entry.
    The x dimension is sliced ``minx -> maxx``. The y slice direction follows
    ``toptobottom``.

    Args:
        bounds: Spatial bounds array in format [minx, miny, maxx, maxy].
        xname: Name of the x-dimension in the dataset.
        yname: Name of the y-dimension in the dataset.
        toptobottom: If truthy (True or 1), y is sliced as
            ``slice(miny, maxy)``, suitable for an ascending y coordinate.
            If falsy (False or 0), y is sliced as ``slice(maxy, miny)``,
            suitable for a descending y coordinate.

    Returns:
        Dictionary keyed by dimension name (x, then y) whose values are
        slice objects.

    Examples:
        >>> bounds = np.array([-180, -90, 180, 90])
        >>> subset_dict = build_subset_tiff_da(
        ...     bounds, 'x', 'y', True
        ... )
        >>> raster_subset = raster_dataarray.sel(subset_dict)
    """
    start_y, stop_y = bounds[1], bounds[3]
    if not toptobottom:
        # Descending y coordinate: flip so the slice runs high -> low.
        start_y, stop_y = stop_y, start_y
    return {
        xname: slice(bounds[0], bounds[2]),
        yname: slice(start_y, stop_y),
    }
def check_gridded_data_for_dimensions(ds: xr.Dataset, vars: list[str]) -> None:
    """Check that gridded data has the required dimensions.

    Checks each specified DataArray in an xarray Dataset to confirm that it
    has exactly three dimensions and that its first dimension is named
    ``"time"``. This is a pre-requisite for many gdptools processing functions.

    Args:
        ds: The xarray Dataset to validate.
        vars: A list of variable names within the dataset to check.

    Raises:
        KeyError: If any of the specified variables do not have exactly
            three dimensions or if 'time' is not the first dimension. All
            offending variables are reported together. A name that is not in
            ``ds`` also raises ``KeyError``, from the lookup itself.
    """
    bad_vars = []
    for var in vars:
        # Check ``dims`` (the variable's own dimension order), not ``indexes``.
        # ``indexes`` only holds dimensions that carry an index coordinate, so
        # its first key need not be the first dimension. It is empty for
        # coordinate-less data, which made ``next(iter(...))`` raise
        # StopIteration.
        dims = ds[var].dims
        if len(dims) != 3 or dims[0] != "time":
            bad_vars.append(var)
    if bad_vars:
        raise KeyError(
            "Cannot process these DataArrays because their dimensions do not match the "
            f"requirements of GDPtools: {bad_vars}"
        )
def _fetch_collection_if_id_matches(
    url: str, collection_id: str
) -> pystac.Collection | None:
    """Load the Collection at ``url`` and return it only if its ID matches, else None."""
    try:
        candidate = pystac.Collection.from_file(url)
    except Exception:
        # Probing is best-effort: any failure here (e.g. an HTTP error or a
        # document that is not a Collection) only means "not at this URL".
        logger.debug("Failed to resolve STAC link %s", url, exc_info=True)
        return None
    return candidate if candidate.id == collection_id else None


def _find_collection_recursive(
    parent: pystac.Catalog | pystac.Collection,
    target_id: str,
    depth: int = 0,
    max_depth: int = 5,
) -> pystac.Collection | None:
    """Depth-first search of ``parent``'s child links for a Collection named ``target_id``."""
    if depth > max_depth:
        return None
    for link in parent.get_child_links():
        # Resolve children one at a time so a single broken STAC link
        # doesn't abort the entire search.
        try:
            link.resolve_stac_object(root=parent.get_root())
            child = link.target
        except Exception:
            logger.debug("Failed to resolve STAC child link", exc_info=True)
            continue
        # Only a Collection satisfies the documented return type. A Catalog
        # that happens to share the ID is searched through instead.
        if child.id == target_id and isinstance(child, pystac.Collection):
            return child
        found = _find_collection_recursive(child, target_id, depth + 1, max_depth)
        if found is not None:
            return found
    return None


def get_stac_collection(collection_id: str) -> pystac.Collection:
    """Fetch a collection from the NHGF STAC catalog.

    Lookup proceeds from cheapest to most expensive:

    1. ``{catalog}/{collection_id}`` is fetched directly. This is a single
       HTTP request and does not load the root catalog.
    2. The root catalog is loaded, and ``{catalog}/{parent}/{collection_id}``
       is tried for each top-level child ``parent``.
    3. As a fallback, the catalog tree is traversed recursively (to a bounded
       depth). This covers nested collections whose API path doesn't match
       their ID.

    Args:
        collection_id: The collection identifier (e.g., ``"conus404_daily"``,
            ``"nlcd-LndCov"``).

    Returns:
        The pystac Collection object.

    Raises:
        STACCatalogError: If the direct lookup misses and the root catalog
            cannot be loaded, or if the collection is not found.
    """
    base = NHGF_STAC_CATALOG_URL.rstrip("/")

    # Fast path: collection served directly under the catalog root. Trying
    # this before loading the catalog keeps the common case to one request.
    collection = _fetch_collection_if_id_matches(f"{base}/{collection_id}", collection_id)
    if collection is not None:
        return collection

    try:
        catalog = pystac.Catalog.from_file(NHGF_STAC_CATALOG_URL)
    except Exception as e:
        raise STACCatalogError(f"Failed to load NHGF STAC catalog: {e}") from e

    # Second fast path: the pygeoapi STAC API serves nested collections at
    # {base}/{parent}/{child}, with parent IDs taken from the top-level links.
    for link in catalog.get_child_links():
        parent_id = link.href.rstrip("/").split("/")[-1]
        if parent_id == collection_id:
            # Skip the meaningless self-nested {id}/{id} path.
            continue
        collection = _fetch_collection_if_id_matches(
            f"{base}/{parent_id}/{collection_id}", collection_id
        )
        if collection is not None:
            return collection

    # Slow path: recursive traversal of the whole catalog tree.
    collection = _find_collection_recursive(catalog, collection_id)
    if collection is not None:
        return collection
    raise STACCatalogError(
        f"Collection '{collection_id}' not found in NHGF STAC catalog. "
        f"Use the catalog URL to browse available collections: {NHGF_STAC_CATALOG_URL}"
    )