---
jupytext:
  formats: ipynb,md:myst
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.18.1
kernelspec:
  display_name: Python 3 (ipykernel)
  language: python
  name: python3
---

# Zonal statistics of rasters

> **Deprecation Notice:** The `serial`, `parallel`, and `dask` zonal engines are
> deprecated and will be removed in a future version. Use `zonal_engine="exactextract"`
> instead — it provides better performance and handles large datasets with bounded memory.
> See the [ExactExtract Engine](#exactextract-engine-high-performance) section below.

```{code-cell} ipython3
import geopandas as gpd
import pandas as pd
import rioxarray as rxr
import xarray as xr

from gdptools import ZonalGen
from gdptools.data.user_data import UserTiffData
```

```{code-cell} ipython3
# Be sure to use rioxarray to read the GeoTIFFs
rds_slope = rxr.open_rasterio("../../../tests/data/rasters/slope/slope.tif")
rds_text = rxr.open_rasterio("../../../tests/data/rasters/TEXT_PRMS/TEXT_PRMS.tif")
# Use geopandas to read the shapefile.
# Ideally the shapefile contains only the feature ID and geometry columns.
gdf = gpd.read_file("../../../tests/data/Oahu.shp")
id_feature = "fid"
# Number of unique feature IDs in the target polygons
print(len(gdf.groupby(id_feature)))
# These parameters are used to fill out the UserTiffData class
tx_name = 'x'
ty_name = 'y'
band = 'band'
crs = 26904  # EPSG:26904, NAD83 / UTM zone 4N
varname = "slope"  # not currently used
categorical = False  # whether the data are categorical; integer-valued data
                     # should usually be treated as categorical.
data = UserTiffData(
    source_var=varname,
    source_ds=rds_slope,
    source_crs=crs,
    source_x_coord=tx_name,
    source_y_coord=ty_name,
    band=1,
    bname=band,
    target_gdf=gdf,
    target_id=id_feature
)

zonal_gen = ZonalGen(
    user_data=data,
    zonal_engine="serial",
    zonal_writer="csv",
    out_path=".",
    file_prefix="Oahu_slope_stats"
)
stats = zonal_gen.calculate_zonal(categorical=categorical)
print(stats)

# stats = zonal_gen.calculate_zonal(categorical=True)
# print(stats)
```

```{code-cell} ipython3
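# Align both tables by feature ID, then copy the mean values positionally
# (assumes exactly one row per feature ID in each table)
gdf.sort_values(by=id_feature, inplace=True)
stats.sort_values(by=id_feature, inplace=True)
gdf["slope_mean"] = stats["mean"].values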
gdf
```

```{code-cell} ipython3
gdf.plot(column="slope_mean", legend=True)
```

## Parallel engine

On this small domain the parallel engine is slower than the serial one, because the overhead of starting the worker processes exceeds the time spent on the calculation itself. On a large domain, however, the `"parallel"` zonal_engine is faster and has a smaller memory footprint.
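The trade-off described above can be made concrete with a deliberately simplified cost model: a fixed start-up overhead plus per-polygon work divided evenly across `jobs`. The function names and the timing constants below are hypothetical illustrations, not gdptools internals or measurements.

```python
# Simplified, hypothetical cost model for a process-based parallel engine:
#   serial time   = n_features * t_feature
#   parallel time = overhead + n_features * t_feature / jobs
# The constants used below are illustrative assumptions only.


def serial_time(n_features: int, t_feature: float) -> float:
    """Total time when every polygon is processed one after another."""
    return n_features * t_feature


def parallel_time(n_features: int, t_feature: float, jobs: int, overhead: float) -> float:
    """Fixed worker start-up overhead plus the work split evenly across jobs."""
    return overhead + n_features * t_feature / jobs


def break_even_features(t_feature: float, jobs: int, overhead: float) -> float:
    """Polygon count at which the parallel run stops being slower.

    Solving n * t = overhead + n * t / jobs for n gives
    n = overhead / (t * (1 - 1 / jobs)).
    """
    if jobs < 2:
        raise ValueError("a parallel run needs at least 2 jobs")
    return overhead / (t_feature * (1.0 - 1.0 / jobs))


t_feature = 0.01  # assumed seconds per polygon
overhead = 2.0    # assumed seconds to start the worker processes
jobs = 2
print(f"break-even at about {break_even_features(t_feature, jobs, overhead):.0f} polygons")
for n in (50, 400, 5000):
    print(n, serial_time(n, t_feature), parallel_time(n, t_feature, jobs, overhead))
```

With these assumed constants the two engines break even at about 400 polygons; below that the serial engine wins, above it the parallel engine does.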

```{code-cell} ipython3
zonal_gen2 = ZonalGen(
    user_data=data,
    zonal_engine="parallel",
    zonal_writer="csv",
    out_path=".",
    file_prefix="Oahu_slope_stats_p",
    jobs=2
)
stats = zonal_gen2.calculate_zonal(categorical=categorical)
print(stats)
```

```{code-cell} ipython3
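# Align both tables by feature ID, then copy the mean values positionally
# (assumes exactly one row per feature ID in each table)
gdf.sort_values(by=id_feature, inplace=True)
stats.sort_values(by=id_feature, inplace=True)
gdf["slope_mean"] = stats["mean"].values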
gdf
```

```{code-cell} ipython3
gdf.plot(column="slope_mean", legend=True)
```

(exactextract-engine-high-performance)=
## ExactExtract Engine (High Performance)

gdptools now supports the `exactextract` engine, which computes zonal statistics typically **10-100x faster** than the vector-based engines by operating directly on raster data instead of converting cells to vector polygons.

### Engine Comparison

| Engine | Status | Description | Best For | Weighted Stats |
|--------|--------|-------------|----------|----------------|
| `exactextract` | **Recommended** | Raster-native C++ library | **Any size, fastest** | Planned ([#94](https://code.usgs.gov/wma/nhgf/toolsteam/gdptools/-/issues/94)) |
| `serial` | Deprecated | Single-threaded, vector-based | Small datasets, debugging | Yes |
| `parallel` | Deprecated | Parallel workers via joblib | Medium datasets | Yes |
| `dask` | Deprecated | Distributed via Dask cluster | Very large datasets | Yes |

### Why is exactextract faster?

1. **Raster-native**: Operates directly on raster blocks without creating polygon geometries
2. **Coverage fractions**: Computes cell coverage analytically (no expensive intersection calls)
3. **Streaming**: Processes raster blocks on-demand without full array materialization
4. **C++ backend**: Core algorithms implemented in optimized C++
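To make the coverage-fraction idea concrete, the sketch below weights each cell value by the share of that cell lying inside a polygon. It assumes an axis-aligned rectangle on a regular grid, where coverage reduces to interval overlaps; exactextract's own algorithm handles arbitrary polygons and is not reproduced here. The helper names are hypothetical.

```python
# Illustrative sketch: coverage-fraction-weighted mean for an axis-aligned
# rectangle over a regular grid. Not exactextract's algorithm; the rectangle
# case is used only because its coverage fractions are simple interval overlaps.


def overlap(a0: float, a1: float, b0: float, b1: float) -> float:
    """Length of the overlap between intervals [a0, a1] and [b0, b1]."""
    return max(0.0, min(a1, b1) - max(a0, b0))


def coverage_fractions(rect, x0, y0, cell, nrows, ncols):
    """Return {(row, col): fraction of the cell covered by rect}.

    rect is (xmin, ymin, xmax, ymax); (x0, y0) is the lower-left grid corner
    and rows count upward from y0 (a simplifying assumption).
    """
    xmin, ymin, xmax, ymax = rect
    fractions = {}
    for r in range(nrows):
        for c in range(ncols):
            cx0, cy0 = x0 + c * cell, y0 + r * cell
            area = overlap(cx0, cx0 + cell, xmin, xmax) * overlap(cy0, cy0 + cell, ymin, ymax)
            if area > 0:  # skip cells the polygon does not touch
                fractions[(r, c)] = area / (cell * cell)
    return fractions


def weighted_mean(values, fractions):
    """Mean of cell values weighted by how much of each cell is covered."""
    total = sum(fractions.values())
    return sum(values[r][c] * f for (r, c), f in fractions.items()) / total


values = [[1.0, 2.0], [3.0, 4.0]]  # 2 x 2 grid, row 0 at the bottom
fracs = coverage_fractions((0.5, 0.0, 2.0, 1.0), x0=0.0, y0=0.0, cell=1.0, nrows=2, ncols=2)
print(fracs)  # {(0, 0): 0.5, (0, 1): 1.0}
print(weighted_mean(values, fracs))
```

The half-covered cell contributes half as much as the fully covered one, so the result (5/3) sits closer to 2.0 than a plain average of the two touched cells would.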

**Note**: `exactextract` does not yet support weighted zonal statistics. Weighted support via exactextract's built-in coverage fractions is planned in [issue #94](https://code.usgs.gov/wma/nhgf/toolsteam/gdptools/-/issues/94). Until then, the deprecated engines remain available for `WeightedZonalGen`.

```{code-cell} ipython3
import time

# Build a fresh UserTiffData whose target GeoDataFrame holds only the ID and
# geometry columns (gdf now also carries slope_mean)
data_slope = UserTiffData(
    source_var="slope",
    source_ds=rds_slope,
    source_crs=crs,
    source_x_coord=tx_name,
    source_y_coord=ty_name,
    band=1,
    bname=band,
    target_gdf=gdf[[id_feature, "geometry"]].copy(),
    target_id=id_feature
)

# ExactExtract engine - fastest option
zonal_gen_exact = ZonalGen(
    user_data=data_slope,
    zonal_engine="exactextract",  # New high-performance engine!
    zonal_writer="csv",
    out_path=".",
    file_prefix="Oahu_slope_stats_exact"
)

start = time.perf_counter()
stats_exact = zonal_gen_exact.calculate_zonal(categorical=False)
exact_time = time.perf_counter() - start
print(f"\nexactextract completed in {exact_time:.4f} seconds")
print(stats_exact.head())
```

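### Timing Comparison: Serial, Parallel, and ExactExtract

Let's compare the performance of the serial, parallel, and exactextract engines on the same dataset. On a domain this small, fixed overheads dominate, so treat single-run timings as indicative only: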

```{code-cell} ipython3
timing_results = {}

# Serial engine
zonal_serial = ZonalGen(
    user_data=data_slope,
    zonal_engine="serial",
    zonal_writer="csv",
    out_path=".",
    file_prefix="timing_serial"
)
start = time.perf_counter()
_ = zonal_serial.calculate_zonal(categorical=False)
timing_results["serial"] = time.perf_counter() - start

# Parallel engine
zonal_parallel = ZonalGen(
    user_data=data_slope,
    zonal_engine="parallel",
    zonal_writer="csv",
    out_path=".",
    file_prefix="timing_parallel",
    jobs=4
)
start = time.perf_counter()
_ = zonal_parallel.calculate_zonal(categorical=False)
timing_results["parallel"] = time.perf_counter() - start

# ExactExtract engine
zonal_exact = ZonalGen(
    user_data=data_slope,
    zonal_engine="exactextract",
    zonal_writer="csv",
    out_path=".",
    file_prefix="timing_exact"
)
start = time.perf_counter()
_ = zonal_exact.calculate_zonal(categorical=False)
timing_results["exactextract"] = time.perf_counter() - start

# Display results
print("\n" + "="*50)
print("TIMING COMPARISON")
print("="*50)
for engine, t in timing_results.items():
    speedup = timing_results["serial"] / t
    print(f"{engine:15} : {t:8.4f} sec  ({speedup:5.1f}x vs serial)")
print("="*50)
```
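Single `perf_counter` measurements on a small domain are noisy. A common remedy is to repeat each call and keep the fastest run; the helper below is a minimal sketch of that idea (`best_of` is a hypothetical name, not part of gdptools).

```python
import time
from typing import Callable


def best_of(fn: Callable[[], object], repeats: int = 5) -> float:
    """Call fn `repeats` times and return the fastest wall-clock time in seconds.

    Taking the minimum filters out one-off noise such as first-call imports,
    cache warm-up, or other processes competing for the CPU.
    """
    if repeats < 1:
        raise ValueError("repeats must be at least 1")
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return min(times)


# Stand-in workload so the sketch runs on its own
t = best_of(lambda: sum(i * i for i in range(100_000)), repeats=3)
print(f"best of 3: {t:.4f} sec")
```

In the notebook the same helper could wrap a zonal run, for example `best_of(lambda: zonal_exact.calculate_zonal(categorical=False))`. Because `zonal_writer="csv"` is set, each call presumably also writes its output file, so the measured time would include that I/O.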

### Key Takeaways

- **exactextract** is the recommended engine — typically **10-100x faster** than vector-based engines because it operates directly on raster blocks without constructing intersection geometries.
- The **serial**, **parallel**, and **dask** engines are **deprecated** and will be removed in a future version.
- Use **exactextract** for all new zonal statistics work (mean, sum, min, max, etc.).
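- **Weighted zonal statistics** via exactextract are planned ([issue #94](https://code.usgs.gov/wma/nhgf/toolsteam/gdptools/-/issues/94)). Until then, `WeightedZonalGen` still uses the deprecated engines.

## Categorical Data

The TEXT_PRMS raster is treated as categorical, so the same workflow is repeated below with `categorical=True`, starting with the serial engine.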

```{code-cell} ipython3
varname = "TEXT"  # not currently used
categorical = True  # whether the data are categorical; integer-valued data
                    # should usually be treated as categorical.
data = UserTiffData(
    source_var=varname,
    source_ds=rds_text,  # Use rds_text for categorical example
    source_crs=crs,
    source_x_coord=tx_name,
    source_y_coord=ty_name,
    band=1,
    bname=band,
    target_gdf=gdf[[id_feature, "geometry"]].copy(),  # ID and geometry only
    target_id=id_feature
)

zonal_gen = ZonalGen(
    user_data=data,
    zonal_engine="serial",
    zonal_writer="csv",
    out_path=".",
    file_prefix="Oahu_TEXT_stats"
)
stats = zonal_gen.calculate_zonal(categorical=categorical)
print(stats)
```

```{code-cell} ipython3
# Compute the max category for each feature ID
max_category_df = stats.iloc[:, :-1].idxmax(axis=1).reset_index()
max_category_df.columns = [id_feature, 'max_category']
```
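`idxmax(axis=1)` returns, for each row, the label of the first column holding that row's maximum, so ties resolve to whichever category column comes first. The toy table below uses made-up values and a hypothetical layout (one row per feature ID, one column per category) to show that behavior; it omits the extra last column that the cell above drops with `iloc[:, :-1]`.

```python
import pandas as pd

# Toy per-category table: index named "fid", one column per category code.
# The values are invented for illustration; feature 20 has a tie.
toy = pd.DataFrame(
    {"1": [0.2, 0.5], "2": [0.8, 0.5]},
    index=pd.Index([10, 20], name="fid"),
)

# Label of the column with the row maximum; ties go to the first column
dominant = toy.idxmax(axis=1).reset_index()
dominant.columns = ["fid", "max_category"]
print(dominant)
```

Feature 10 is assigned category `"2"`, while the tied feature 20 is assigned `"1"` simply because that column comes first.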

```{code-cell} ipython3
# Merge the max category result into the GeoDataFrame
gdf = gdf.merge(max_category_df, on=id_feature, how='left')
gdf
```

```{code-cell} ipython3
gdf.plot(column="max_category", legend=True)
```

## Parallel Engine with Categorical Data

```{code-cell} ipython3
zonal_gen2 = ZonalGen(
    user_data=data,
    zonal_engine="parallel",
    zonal_writer="csv",
    out_path=".",
    file_prefix="Oahu_TEXT_stats_p",  # distinct prefix so the serial CSV is not overwritten
    jobs=2
)
stats = zonal_gen2.calculate_zonal(categorical=categorical)
print(stats)
```

```{code-cell} ipython3
# Compute the max category for each feature ID
max_category_df = stats.iloc[:, :-1].idxmax(axis=1).reset_index()
max_category_df.columns = [id_feature, 'max_category_p']

gdf = gdf.merge(max_category_df, on=id_feature, how='left')
gdf
```

```{code-cell} ipython3
gdf.plot(column="max_category_p", legend=True)
```

## ExactExtract Engine with Categorical Data

The `exactextract` engine also supports categorical zonal statistics via the `categorical=True` parameter. For categorical data, it reports, for each polygon, the count or area of every unique category present.
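A categorical tally can be sketched in plain Python by summing, per category, the coverage fraction of every cell a polygon touches. This is an illustration under assumptions: the `(category, coverage_fraction)` input format and the helper name are hypothetical, and whether the engine reports counts, areas, or fractions is not specified here. Treating every touched cell as fraction 1.0 would turn this into a plain cell count.

```python
from collections import defaultdict


def category_fractions(cells):
    """Share of a polygon's covered area that belongs to each category.

    `cells` is an iterable of (category, coverage_fraction) pairs, one per
    raster cell touched by the polygon, with 0 < coverage_fraction <= 1.
    """
    weight = defaultdict(float)
    for category, frac in cells:
        weight[category] += frac  # covered area, in units of whole cells
    total = sum(weight.values())
    return {cat: w / total for cat, w in weight.items()}


# Hypothetical cells for one polygon: category codes with coverage fractions
cells = [(3, 1.0), (3, 0.5), (7, 1.0), (7, 1.0), (7, 0.25), (9, 0.25)]
fractions = category_fractions(cells)
majority = max(fractions, key=fractions.get)  # dominant category
print(fractions, majority)
```

Here category 7 covers 2.25 of the 4.0 covered cell-equivalents, so it is the dominant category, which is the same quantity the `idxmax` step below extracts from the engine's output table.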

```{code-cell} ipython3
# Set up categorical data using TEXT raster
data_text = UserTiffData(
    source_var="TEXT",
    source_ds=rds_text,
    source_crs=crs,
    source_x_coord=tx_name,
    source_y_coord=ty_name,
    band=1,
    bname=band,
    target_gdf=gdf[[id_feature, "geometry"]].copy(),
    target_id=id_feature
)

# ExactExtract engine with categorical data
zonal_gen_exact_cat = ZonalGen(
    user_data=data_text,
    zonal_engine="exactextract",
    zonal_writer="csv",
    out_path=".",
    file_prefix="Oahu_TEXT_stats_exact"
)

start = time.perf_counter()
stats_exact_cat = zonal_gen_exact_cat.calculate_zonal(categorical=True)
exact_cat_time = time.perf_counter() - start
print(f"\nexactextract (categorical) completed in {exact_cat_time:.4f} seconds")
print(stats_exact_cat.head())
```

```{code-cell} ipython3
# Compute the max category for each polygon
max_category_exact = stats_exact_cat.iloc[:, :-1].idxmax(axis=1).reset_index()
max_category_exact.columns = [id_feature, 'max_category_exact']

# Merge into GeoDataFrame and plot
gdf = gdf.merge(max_category_exact, on=id_feature, how='left')
gdf.plot(column="max_category_exact", legend=True)
```
