Data Validation with Pyspark Pandas
New in version 0.10.0
PySpark is a distributed compute framework that offers a pandas drop-in
replacement dataframe implementation via the pyspark.pandas API.
You can use pandera to validate pyspark.pandas DataFrame() and Series()
objects directly. First, install pandera with the pyspark extra:
pip install 'pandera[pyspark]'
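Note that when pyarrow>=2.0.0 is installed, pyspark.pandas expects the
PYARROW_IGNORE_TIMEZONE environment variable to be set to "1" on both the
driver and executors, and emits a warning otherwise. A minimal sketch of
setting it before importing pyspark.pandas:
import os

# pyspark.pandas warns if this isn't set when pyarrow>=2.0.0 is installed;
# set it before a Spark context is launched so it propagates to executors.
os.environ["PYARROW_IGNORE_TIMEZONE"] = "1"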
Then you can use pandera schemas to validate pyspark dataframes. In the example
below we'll use the class-based API to define a DataFrameModel for validation.
import pyspark.pandas as ps
import pandas as pd
import pandera as pa
from pandera.typing.pyspark import DataFrame, Series
class Schema(pa.DataFrameModel):
state: Series[str]
city: Series[str]
price: Series[int] = pa.Field(in_range={"min_value": 5, "max_value": 20})
# create a pyspark.pandas dataframe that's validated on object initialization
df = DataFrame[Schema](
{
'state': ['FL','FL','FL','CA','CA','CA'],
'city': [
'Orlando',
'Miami',
'Tampa',
'San Francisco',
'Los Angeles',
'San Diego',
],
'price': [8, 12, 10, 16, 20, 18],
}
)
print(df)
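Because the dataframe is validated on object initialization, data that violates
the schema raises an error at construction time. The snippet below is an
illustrative sketch (the out-of-range price is made up) that reuses the imports
and Schema defined above and catches pandera's SchemaError:
try:
    DataFrame[Schema](
        {
            "state": ["FL"],
            "city": ["Orlando"],
            "price": [100],  # outside the in_range(5, 20) check declared above
        }
    )
except pa.errors.SchemaError as exc:
    print(exc)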
You can also use the check_types() decorator to validate pyspark pandas
dataframes at runtime:
@pa.check_types
def function(df: DataFrame[Schema]) -> DataFrame[Schema]:
return df[df["state"] == "CA"]
print(function(df))
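check_types() validates both the annotated input and the return value against
DataFrame[Schema], so calling the function with data that doesn't satisfy the
schema raises a SchemaError. A hedged sketch, reusing the Schema above with an
illustrative invalid frame:
import pyspark.pandas as ps

bad_df = ps.DataFrame(
    {
        "state": ["CA"],
        "city": ["San Diego"],
        "price": [100],  # violates the in_range(5, 20) check
    }
)

try:
    function(bad_df)
except pa.errors.SchemaError as exc:
    print(exc)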
And of course, you can use the object-based API to validate pyspark pandas dataframes:
schema = pa.DataFrameSchema({
"state": pa.Column(str),
"city": pa.Column(str),
"price": pa.Column(int, pa.Check.in_range(min_value=5, max_value=20))
})
schema(df)
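To collect every failure case instead of raising on the first error, you can
pass lazy=True to validate(); pandera then raises SchemaErrors, whose
failure_cases attribute summarizes all failed checks. A minimal sketch (with
the valid df above this simply passes):
try:
    schema.validate(df, lazy=True)
except pa.errors.SchemaErrors as exc:
    # dataframe listing every failing column, check, and value
    print(exc.failure_cases)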