Data Validation with Pyspark Pandas
new in 0.10.0
Pyspark is a
distributed compute framework that offers a drop-in replacement for the pandas
dataframe via the pyspark.pandas API.
You can use pandera to validate pyspark.pandas DataFrame()
and Series() objects directly. First, install
pandera with the pyspark extra:
pip install 'pandera[pyspark]'
Then you can use pandera schemas to validate pyspark dataframes. In the example
below we’ll use the class-based API to define a
DataFrameModel for validation.
import pyspark.pandas as ps
import pandas as pd
import pandera.pandas as pa
from pandera.typing.pyspark import DataFrame, Series
class Schema(pa.DataFrameModel):
    state: Series[str]
    city: Series[str]
    price: Series[int] = pa.Field(in_range={"min_value": 5, "max_value": 20})
# create a pyspark.pandas dataframe that's validated on object initialization
df = DataFrame[Schema](
    {
        'state': ['FL', 'FL', 'FL', 'CA', 'CA', 'CA'],
        'city': [
            'Orlando',
            'Miami',
            'Tampa',
            'San Francisco',
            'Los Angeles',
            'San Diego',
        ],
        'price': [8, 12, 10, 16, 20, 18],
    }
)
print(df)
  state           city  price
0    FL        Orlando      8
1    FL          Miami     12
2    FL          Tampa     10
3    CA  San Francisco     16
4    CA    Los Angeles     20
5    CA      San Diego     18
You can also use the check_types() decorator to validate
pyspark pandas dataframes at runtime:
@pa.check_types
def function(df: DataFrame[Schema]) -> DataFrame[Schema]:
    return df[df["state"] == "CA"]
print(function(df))
  state           city  price
3    CA  San Francisco     16
4    CA    Los Angeles     20
5    CA      San Diego     18
And of course, you can use the object-based API to validate pyspark pandas dataframes:
schema = pa.DataFrameSchema({
    "state": pa.Column(str),
    "city": pa.Column(str),
    "price": pa.Column(int, pa.Check.in_range(min_value=5, max_value=20)),
})
schema(df)
| | state | city | price |
|---|---|---|---|
| 0 | FL | Orlando | 8 |
| 1 | FL | Miami | 12 |
| 2 | FL | Tampa | 10 |
| 3 | CA | San Francisco | 16 |
| 4 | CA | Los Angeles | 20 |
| 5 | CA | San Diego | 18 |