Decorators

check_input / check_output

These decorators accept any schema object, including DataArraySchema and DatasetSchema, and validate function arguments or return values:

import numpy as np
import xarray as xr
import pandera.xarray as pa

schema = pa.DataArraySchema(dtype=np.float64, dims=("x",))

@pa.check_input(schema, "da")
def process(da: xr.DataArray) -> xr.DataArray:
    return da * 2

@pa.check_output(schema)
def generate() -> xr.DataArray:
    return xr.DataArray(np.ones(3), dims="x")

da = xr.DataArray(np.array([1.0, 2.0, 3.0]), dims="x")
process(da)
<xarray.DataArray (x: 3)> Size: 24B
array([2., 4., 6.])
Dimensions without coordinates: x
generate()
<xarray.DataArray (x: 3)> Size: 24B
array([1., 1., 1.])
Dimensions without coordinates: x

If the input or output doesn’t match the schema, a SchemaError is raised:

bad_da = xr.DataArray(np.zeros(3), dims=("z",))

try:
    process(bad_da)
except pa.errors.SchemaError as exc:
    print(exc)
dim position 0: expected 'x', got 'z'
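
To make the name-based lookup concrete, here is a minimal standard-library sketch of how a check_input-style decorator might find the argument named "da", whether it is passed positionally or by keyword. This is not pandera's implementation: check_named_input, require_dims_x, and the dict standing in for a DataArray are all hypothetical names used only for illustration.

```python
import functools
import inspect


def check_named_input(validate, arg_name):
    """Hypothetical stand-in for a check_input-style decorator."""

    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Map positional and keyword arguments onto parameter names.
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            # Validate only the argument selected by name.
            validate(bound.arguments[arg_name])
            return func(*args, **kwargs)

        return wrapper

    return decorator


def require_dims_x(obj):
    # Toy validator: a plain dict stands in for a DataArray here.
    if obj["dims"] != ("x",):
        raise ValueError(f"expected dims ('x',), got {obj['dims']!r}")


@check_named_input(require_dims_x, "da")
def double_values(da, factor=2):
    return {"dims": da["dims"], "values": [v * factor for v in da["values"]]}


good = double_values({"dims": ("x",), "values": [1.0, 2.0]})
```

Because the wrapper binds arguments through the function signature, the same check runs for double_values(obj) and double_values(da=obj).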

check_io

check_io combines input and output validation in a single decorator. Pass keyword arguments matching parameter names for inputs and out for the return value:

in_schema = pa.DataArraySchema(dtype=np.float64, dims=("x",))
out_schema = pa.DataArraySchema(dtype=np.float64, dims=("x",))

@pa.check_io(da=in_schema, out=out_schema)
def scale(da: xr.DataArray) -> xr.DataArray:
    return da * 10

da = xr.DataArray(np.array([1.0, 2.0, 3.0]), dims="x")
scale(da)
<xarray.DataArray (x: 3)> Size: 24B
array([10., 20., 30.])
Dimensions without coordinates: x

check_types

check_types inspects type annotations and validates against the referenced model. Use the generic types from pandera.typing.xarray:

from pandera.typing.xarray import Coordinate, DataArray, Dataset

class Temperature(pa.DataArrayModel):
    data: np.float64 = pa.Field()
    x: Coordinate[np.float64]

    class Config:
        dims = ("x",)
        name = "temperature"

@pa.check_types
def transform(da: DataArray[Temperature]) -> DataArray[Temperature]:
    return da * 2

da = xr.DataArray(
    np.ones(5),
    dims="x",
    coords={"x": np.arange(5, dtype=np.float64)},
    name="temperature",
)
transform(da)
<xarray.DataArray 'temperature' (x: 5)> Size: 40B
array([2., 2., 2., 2., 2.])
Coordinates:
  * x        (x) float64 40B 0.0 1.0 2.0 3.0 4.0

For datasets:

class Surface(pa.DatasetModel):
    temperature: np.float64 = pa.Field(dims=("x",))
    x: Coordinate[np.float64]

@pa.check_types
def process_dataset(ds: Dataset[Surface]) -> Dataset[Surface]:
    return ds

ds = xr.Dataset(
    {"temperature": (("x",), np.ones(3))},
    coords={"x": np.arange(3, dtype=np.float64)},
)
process_dataset(ds)
<xarray.Dataset> Size: 48B
Dimensions:      (x: 3)
Coordinates:
  * x            (x) float64 24B 0.0 1.0 2.0
Data variables:
    temperature  (x) float64 24B 1.0 1.0 1.0

Mixed annotations work too, for example a function that takes a DataArray and returns a Dataset:

@pa.check_types
def to_dataset(da: DataArray[Temperature]) -> Dataset[Surface]:
    return xr.Dataset(
        {"temperature": da},
        coords={"x": da.coords["x"]},
    )

to_dataset(da)
<xarray.Dataset> Size: 80B
Dimensions:      (x: 5)
Coordinates:
  * x            (x) float64 40B 0.0 1.0 2.0 3.0 4.0
Data variables:
    temperature  (x) float64 40B 1.0 1.0 1.0 1.0 1.0
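
As a rough illustration of annotation-driven validation, the sketch below reads a function's type hints and calls a validate method on any annotated class that provides one, for both parameters and the return value. It uses only the standard library and is an assumption about the general approach, not pandera's code: check_annotated and Positive are hypothetical, and the model class is used directly as the annotation rather than through a generic such as DataArray[...].

```python
import functools
import inspect
import typing


class Positive:
    """Hypothetical model: values must be greater than zero."""

    @staticmethod
    def validate(value):
        if value <= 0:
            raise ValueError(f"expected a positive value, got {value!r}")
        return value


def check_annotated(func):
    """Hypothetical stand-in for a check_types-style decorator."""
    hints = typing.get_type_hints(func)
    sig = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        # Validate each argument whose annotation exposes validate().
        for name, value in bound.arguments.items():
            model = hints.get(name)
            if hasattr(model, "validate"):
                model.validate(value)
        result = func(*args, **kwargs)
        # Validate the return value the same way.
        returned = hints.get("return")
        if hasattr(returned, "validate"):
            returned.validate(result)
        return result

    return wrapper


@check_annotated
def shift(value: Positive, offset: float) -> Positive:
    return value + offset


shifted = shift(3, 1.0)
```

Annotations without a validate method, such as float above, are skipped, so ordinary parameters pass through unchecked.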

Pass lazy=True to collect all validation errors instead of failing on the first one:

@pa.check_types(lazy=True)
def strict_transform(da: DataArray[Temperature]) -> DataArray[Temperature]:
    return xr.DataArray(np.ones(3), dims=("z",), name="bad")

try:
    strict_transform(da)
except pa.errors.SchemaErrors as exc:
    print(exc)
{
    "SCHEMA": {
        "WRONG_FIELD_NAME": [
            {
                "schema": "temperature",
                "column": "temperature",
                "check": "name",
                "error": "expected name 'temperature', got 'bad'"
            }
        ],
        "MISMATCH_INDEX": [
            {
                "schema": "temperature",
                "column": "temperature",
                "check": "dims",
                "error": "dim position 0: expected 'x', got 'z'"
            }
        ],
        "COLUMN_NOT_IN_DATAFRAME": [
            {
                "schema": "temperature",
                "column": "x",
                "check": "coords",
                "error": "missing coordinate 'x'"
            }
        ]
    }
}
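
The difference between eager and lazy validation can be sketched with plain Python. The example below is a simplified model of the idea, not pandera's implementation: run_checks, CollectedErrors, and the dict standing in for a DataArray are hypothetical names chosen for illustration.

```python
class CollectedErrors(Exception):
    """Hypothetical container for several validation failures."""

    def __init__(self, failures):
        super().__init__("; ".join(failures))
        self.failures = failures


def run_checks(checks, obj, lazy=False):
    failures = []
    for name, check in checks:
        if check(obj):
            continue
        message = f"check {name!r} failed"
        if not lazy:
            # Eager mode: stop at the first failing check.
            raise ValueError(message)
        # Lazy mode: remember the failure and keep going.
        failures.append(message)
    if failures:
        raise CollectedErrors(failures)
    return obj


checks = [
    ("name", lambda o: o["name"] == "temperature"),
    ("dims", lambda o: o["dims"] == ("x",)),
]
bad = {"name": "bad", "dims": ("z",)}

try:
    run_checks(checks, bad, lazy=True)
except CollectedErrors as exc:
    lazy_failures = exc.failures
```

In eager mode the same input raises on the "name" check alone, and the "dims" failure is never reported.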

For DataTree models:

from pandera.typing.xarray import DataTree

class SurfaceDS(pa.DatasetModel):
    temperature: np.float64 = pa.Field(dims=("x",))
    x: Coordinate[np.float64]

class MyTree(pa.DataTreeModel):
    surface: SurfaceDS

@pa.check_types
def process_tree(dt: DataTree[MyTree]) -> DataTree[MyTree]:
    return dt

dt = xr.DataTree.from_dict({
    "/": xr.Dataset(),
    "/surface": xr.Dataset(
        {"temperature": (("x",), np.ones(3))},
        coords={"x": np.arange(3, dtype=np.float64)},
    ),
})
process_tree(dt)
<xarray.DataTree>
Group: /
└── Group: /surface
        Dimensions:      (x: 3)
        Coordinates:
          * x            (x) float64 24B 0.0 1.0 2.0
        Data variables:
            temperature  (x) float64 24B 1.0 1.0 1.0

See Decorators for Pipeline Integration for the full decorator API.

See also