Data Validation with PySpark Pandas

New in version 0.10.0

PySpark is a distributed compute framework that offers a pandas drop-in replacement dataframe implementation via the pyspark.pandas API. You can use pandera to validate pyspark.pandas DataFrame() and Series() objects directly. First, install pandera with the pyspark extra:

pip install 'pandera[pyspark]'
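
If you manage environments with conda, the same functionality should be available through conda-forge (assuming the pandera-pyspark package name used for pandera's optional extras):

conda install -c conda-forge pandera-pyspark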

Then you can use pandera schemas to validate pyspark.pandas dataframes. In the example below, we'll use the class-based API to define a DataFrameModel for validation.

import pyspark.pandas as ps
import pandas as pd
import pandera as pa

from pandera.typing.pyspark import DataFrame, Series


class Schema(pa.DataFrameModel):
    state: Series[str]
    city: Series[str]
    price: Series[int] = pa.Field(in_range={"min_value": 5, "max_value": 20})


# create a pyspark.pandas dataframe that's validated on object initialization
df = DataFrame[Schema](
    {
        'state': ['FL','FL','FL','CA','CA','CA'],
        'city': [
            'Orlando',
            'Miami',
            'Tampa',
            'San Francisco',
            'Los Angeles',
            'San Diego',
        ],
        'price': [8, 12, 10, 16, 20, 18],
    }
)
print(df)
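
If the data doesn't conform to Schema, pandera raises a SchemaError at initialization. Below is a minimal sketch, reusing the Schema and imports defined above, with hypothetical invalid data (a price outside the allowed range) to illustrate this:

invalid = {
    'state': ['FL'],
    'city': ['Orlando'],
    'price': [42],  # hypothetical value outside the allowed [5, 20] range
}

try:
    DataFrame[Schema](invalid)
except pa.errors.SchemaError as exc:
    # the error message reports the failed in_range check on 'price'
    print(exc)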

You can also use the check_types() decorator to validate pyspark.pandas dataframes at runtime:

@pa.check_types
def function(df: DataFrame[Schema]) -> DataFrame[Schema]:
    return df[df["state"] == "CA"]

print(function(df))
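
check_types() validates both the input and the output annotations. As a sketch, a function whose return value violates the output schema (here by dropping a required column, a hypothetical example) raises a SchemaError when called:

@pa.check_types
def drop_price(df: DataFrame[Schema]) -> DataFrame[Schema]:
    # dropping a required column violates the output schema
    return df.drop(columns=["price"])

try:
    drop_price(df)
except pa.errors.SchemaError as exc:
    print(exc)  # reports the missing 'price' column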

And of course, you can use the object-based API to validate pyspark.pandas dataframes:

schema = pa.DataFrameSchema({
    "state": pa.Column(str),
    "city": pa.Column(str),
    "price": pa.Column(int, pa.Check.in_range(min_value=5, max_value=20))
})
schema(df)
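
As with pandas dataframes, you can pass lazy=True to collect all failure cases instead of raising on the first error, assuming the lazy validation API behaves the same way on pyspark.pandas dataframes. A sketch with hypothetical out-of-range prices:

try:
    # hypothetical invalid data: every price is outside the [5, 20] range
    schema.validate(df.assign(price=100), lazy=True)
except pa.errors.SchemaErrors as exc:
    print(exc.failure_cases)  # all collected failure cases in one report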