Extensions (new)

new in 0.6.0

Registering Custom Check Methods

One of the strengths of pandera is its flexibility in enabling you to define in-line custom checks on the fly:

import pandera as pa

# checks elements in a column/dataframe
element_wise_check = pa.Check(lambda x: x < 0, element_wise=True)

# applies the check function to a dataframe/series
vectorized_check = pa.Check(lambda series_or_df: series_or_df < 0)

However, there are two main disadvantages of schemas with inline custom checks:

  1. they are not serializable with the IO interface.

  2. you can’t use them to synthesize data because the checks are not associated with a hypothesis strategy.

pandera now offers a way to register custom checks so that they’re available as methods on the Check class. Here, let’s define a custom check method that verifies whether the elements of a pandas object lie between two values (inclusive).

import pandera as pa
import pandera.extensions as extensions
import pandas as pd

@extensions.register_check_method(statistics=["min_value", "max_value"])
def is_between(pandas_obj, *, min_value, max_value):
    return (min_value <= pandas_obj) & (pandas_obj <= max_value)

schema = pa.DataFrameSchema({
    "col": pa.Column(int, pa.Check.is_between(min_value=1, max_value=10))
})

data = pd.DataFrame({"col": [1, 5, 10]})
print(schema(data))
   col
0    1
1    5
2   10

As you can see, a custom check’s first argument is a pandas series or dataframe by default (more on that later), followed by keyword-only arguments, specified with the * syntax.

register_check_method() requires you to explicitly name the check statistics via the statistics keyword argument. Statistics are essentially the constraints that the check places on the pandas data structure; in this example, min_value and max_value.

Specifying a Check Strategy

To specify a check strategy with your custom check, you’ll need to install the strategies extension (pip install "pandera[strategies]"). First, let’s look at a trivially simple example, where the check verifies whether a column is equal to a certain value:

def custom_equals(pandas_obj, *, value):
    return pandas_obj == value

The corresponding strategy for this check would be:

from typing import Optional
import hypothesis
import pandera.strategies as st

def equals_strategy(
    pandas_dtype: pa.PandasDtype,
    strategy: Optional[st.SearchStrategy] = None,
    *,
    value,
):
    if strategy is None:
        return st.pandas_dtype_strategy(
            pandas_dtype, strategy=hypothesis.strategies.just(value),
        )
    return strategy.filter(lambda x: x == value)

As you may notice, the pandera strategy interface has two positional arguments followed by keyword-only arguments that match the check function’s keyword-only check statistics. The pandas_dtype positional argument is useful for ensuring the correct data type: in the above example, we’re using the pandas_dtype_strategy() strategy to make sure the generated value is of the correct data type.

The optional strategy argument allows us to use the check strategy either as a base strategy or as a chained strategy. Supporting both is our responsibility: to account for strategy chaining, the strategy function body needs to handle two cases:

  1. when the strategy function is being used as a base strategy, i.e. when strategy is None

  2. when the strategy function is being chained from a previously-defined strategy, i.e. when strategy is not None.
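The two cases can be illustrated without pandera at all. The sketch below is a simplified, hypothetical equals_strategy_sketch built only on Hypothesis (hypothesis.strategies.just, .filter, given and settings). It drops the pandas_dtype argument, so it shows the control flow of base versus chained strategies rather than pandera's dtype handling.

```python
from typing import Optional

import hypothesis.strategies as hst
from hypothesis import given, settings


def equals_strategy_sketch(
    strategy: Optional[hst.SearchStrategy] = None,
    *,
    value,
):
    """Hypothetical, dtype-free version of equals_strategy."""
    if strategy is None:
        # case 1: base strategy, so generate the only valid value directly
        return hst.just(value)
    # case 2: chained strategy, so keep only upstream values that pass the check
    return strategy.filter(lambda x: x == value)


base = equals_strategy_sketch(value=3)
chained = equals_strategy_sketch(hst.sampled_from([1, 2, 3]), value=3)

seen = []


@settings(max_examples=20, database=None)
@given(base, chained)
def collect(x, y):
    seen.append((x, y))


collect()
```

Note the asymmetry: the base case produces valid values directly, whereas the chained case can only discard upstream values that fail the check, which becomes slow if the upstream strategy rarely produces valid values.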

Finally, to register the custom check with the strategy, use the register_check_method() decorator:

@extensions.register_check_method(
    statistics=["value"], strategy=equals_strategy
)
def custom_equals(pandas_obj, *, value):
    return pandas_obj == value

Let’s unpack what’s going on here. The custom_equals function has a single statistic, the value argument, which we’ve also listed in register_check_method(). This means that the associated check strategy must accept the same keyword-only arguments.
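One way to confirm by hand that a check and its strategy line up is to compare their keyword-only parameters with the declared statistics. keyword_only_names below is a hypothetical helper built on the standard inspect module, not part of pandera; the two functions are copies of custom_equals and equals_strategy with the pandera-specific type hints removed so the block runs standalone.

```python
import inspect


def keyword_only_names(fn):
    """Hypothetical helper: return the names of fn's keyword-only parameters."""
    return [
        name
        for name, param in inspect.signature(fn).parameters.items()
        if param.kind is inspect.Parameter.KEYWORD_ONLY
    ]


def custom_equals(pandas_obj, *, value):
    return pandas_obj == value


def equals_strategy(pandas_dtype, strategy=None, *, value):
    # body omitted: only the signature matters for this comparison
    raise NotImplementedError


statistics = ["value"]
check_params = keyword_only_names(custom_equals)
strategy_params = keyword_only_names(equals_strategy)
print(check_params, strategy_params)  # ['value'] ['value']
```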

Going back to our is_between function example, here’s what the strategy would look like:

def in_between_strategy(
    pandas_dtype: pa.PandasDtype,
    strategy: Optional[st.SearchStrategy] = None,
    *,
    min_value,
    max_value
):
    if strategy is None:
        return st.pandas_dtype_strategy(
            pandas_dtype,
            min_value=min_value,
            max_value=max_value,
            exclude_min=False,
            exclude_max=False,
        )
    return strategy.filter(lambda x: min_value <= x <= max_value)

@extensions.register_check_method(
    statistics=["min_value", "max_value"],
    strategy=in_between_strategy,
)
def is_between_with_strat(pandas_obj, *, min_value, max_value):
    return (min_value <= pandas_obj) & (pandas_obj <= max_value)
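A useful property to test is that every value a strategy generates actually passes the check it belongs to. The sketch below assumes Hypothesis and pandas are installed and substitutes a hypothetical, integer-only between_values_sketch for in_between_strategy. It uses hypothesis.strategies.integers, whose bounds are inclusive, mirroring exclude_min=False and exclude_max=False, and feeds the generated values through the is_between check function.

```python
import hypothesis.strategies as hst
import pandas as pd
from hypothesis import given, settings


def is_between(pandas_obj, *, min_value, max_value):
    return (min_value <= pandas_obj) & (pandas_obj <= max_value)


def between_values_sketch(strategy=None, *, min_value, max_value):
    """Hypothetical, integer-only stand-in for in_between_strategy."""
    if strategy is None:
        # inclusive bounds, matching the <= comparisons in is_between
        return hst.integers(min_value=min_value, max_value=max_value)
    return strategy.filter(lambda x: min_value <= x <= max_value)


results = []


@settings(max_examples=50, database=None)
@given(hst.lists(between_values_sketch(min_value=1, max_value=10), min_size=1, max_size=20))
def strategy_agrees_with_check(values):
    # every generated value should satisfy the check it was generated for
    passed = is_between(pd.Series(values), min_value=1, max_value=10).all()
    results.append(bool(passed))


strategy_agrees_with_check()
```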

Check Types

The extensions module also supports registering element-wise and groupby checks.

Element-wise Checks

@extensions.register_check_method(
    statistics=["val"],
    check_type="element_wise",
)
def element_wise_equal_check(element, *, val):
    return element == val

Note that the first argument of element_wise_equal_check is a single element in the column or dataframe.

Groupby Checks

In this groupby check, we’re verifying that the values of one column for group_a are, on average, greater than those of group_b:

from typing import Dict

@extensions.register_check_method(
    statistics=["group_a", "group_b"],
    check_type="groupby",
)
def groupby_check(dict_groups: Dict[str, pd.Series], *, group_a, group_b):
    return dict_groups[group_a].mean() > dict_groups[group_b].mean()

data = pd.DataFrame({
    "values": [20, 10, 1, 15],
    "groups": list("xxyy"),
})

schema = pa.DataFrameSchema({
    "values": pa.Column(
        int,
        pa.Check.groupby_check(group_a="x", group_b="y", groupby="groups"),
    ),
    "groups": pa.Column(str),
})

print(schema(data))
   values groups
0      20      x
1      10      x
2       1      y
3      15      y
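It can help to know what dict_groups contains. As a rough mental model (an approximation built with plain pandas, not pandera's internal code), groupby="groups" splits the values column by the groups column and passes the check a dictionary mapping each group key to the corresponding Series:

```python
import pandas as pd

data = pd.DataFrame({
    "values": [20, 10, 1, 15],
    "groups": list("xxyy"),
})

# approximate the dict handed to a groupby check:
# group key -> Series of "values" for that group
dict_groups = {key: series for key, series in data.groupby("groups")["values"]}


def groupby_check(dict_groups, *, group_a, group_b):
    return dict_groups[group_a].mean() > dict_groups[group_b].mean()


print(dict_groups["x"].tolist(), dict_groups["y"].tolist())  # [20, 10] [1, 15]
result = groupby_check(dict_groups, group_a="x", group_b="y")
print(bool(result))  # True: mean of x is 15.0, mean of y is 8.0
```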

Registered Custom Checks with the Class-based API

Since registered checks are part of the Check namespace, you can also use custom checks with the class-based API:

from pandera.typing import Series

class Schema(pa.SchemaModel):
    col1: Series[str] = pa.Field(custom_equals="value")
    col2: Series[int] = pa.Field(is_between={"min_value": 0, "max_value": 10})

data = pd.DataFrame({
    "col1": ["value"] * 5,
    "col2": range(5)
})

print(Schema.validate(data))
    col1  col2
0  value     0
1  value     1
2  value     2
3  value     3
4  value     4