Unit testing Spark code using the pytest library

  • Unit testing Spark code is extremely useful for uncovering functional problems and eliminating bugs before conducting performance tests and finally migrating to production.
  • Spark's support for running in local mode makes it easy to conduct unit tests. One of the most widely used testing libraries in the Python world is pytest, which is less verbose and provides great support for reusable fixtures and parametrization in fixtures (a minimal illustration follows this list).
  • In this tutorial we will demonstrate the steps required to write unit test cases for PySpark applications.
  • We will use the captains performance Spark code for this unit testing example, but we will rewrite those operations as functions, which can then be tested using unit test functions.
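
As a quick illustration of those two pytest features, the sketch below defines a reusable, parametrized fixture; the fixture name number and its values are purely hypothetical:

import pytest

# A parametrized fixture: every test that requests `number`
# runs once for each value in params.
@pytest.fixture(params=[1, 2, 3])
def number(request):
    return request.param


def test_is_positive(number):
    # Runs three times, once per fixture parameter
    assert number > 0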

Writing the Spark application

  • Write the following code in a file called captains_module.py.
  • It defines two functions

    • getNumCaptainsByMinMatches() - Takes an RDD and returns the number of captains who captained more than a certain number of matches, which is also passed as a parameter.
    • getNumMatchesPerCountry() - Takes an RDD and returns the total number of matches played by each country.

from collections import namedtuple

fields = ("name", "country", "career", "matches", "won", "lost", "ties", "toss")

Captain = namedtuple("Captain", fields)


def parseRecs(line):
    # Parse one comma-separated record into a Captain named tuple
    cols = line.split(",")
    return Captain(cols[0], cols[1], cols[2], int(cols[3]),
                   int(cols[4]), int(cols[5]), int(cols[6]), int(cols[7]))


def getNumCaptainsByMinMatches(anRDD, num_matches=100):
    # Count the captains who captained more than num_matches matches
    return anRDD.map(parseRecs) \
                .filter(lambda rec: rec.matches > num_matches) \
                .count()


def getNumMatchesPerCountry(anRDD):
    # Total number of matches per country, sorted in descending order
    return anRDD.map(parseRecs) \
                .map(lambda rec: (rec.country, rec.matches)) \
                .reduceByKey(lambda a, b: a + b) \
                .sortBy(lambda rec: rec[1], ascending=False)
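
For reference, the two functions can also be exercised outside the test suite with a small standalone driver. A minimal sketch, where captains.csv is a hypothetical input file with the same comma-separated layout as the test records in the next section:

from pyspark import SparkConf, SparkContext
from captains_module import getNumCaptainsByMinMatches, getNumMatchesPerCountry

if __name__ == "__main__":
    conf = SparkConf().setMaster("local[2]").setAppName("captains-driver")
    sc = SparkContext(conf=conf)
    captains_rdd = sc.textFile("captains.csv")  # hypothetical input path
    print(getNumCaptainsByMinMatches(captains_rdd, 100))
    print(getNumMatchesPerCountry(captains_rdd).take(5))
    sc.stop()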

Writing the unit test cases

Now write the following code in another file called captains_test.py.


from captains_module import getNumCaptainsByMinMatches, getNumMatchesPerCountry
from pyspark import SparkConf, SparkContext
import pytest

test_input_data = ['Ponting  R T,Australia,1995-2012,230,165,51,14,124',
                   'Fleming  S P,New Zealand,1994-2007,218,98,106,14,105',
                   'Ranatunga  A,Sri Lanka,1982-1999,193,89,95,9,102',
                   'Dhoni  M S*,India,2004-,186,103,68,15,88',
                   'Border  A R,Australia,1979-1994,178,107,67,4,86',
                   'Crowe  M D,New Zealand,1982-1995,44,21,22,1,28',
                   'Atherton  M A,England,1990-1998,43,20,21,2,20',
                   'Walsh  C A,West Indies,1985-2000,43,22,20,1,20']

@pytest.fixture(scope="session")

def spark_context(request):

  conf = (SparkConf().setMaster("local[2]").setAppName("captains-unittest"))
  sc = SparkContext(conf=conf)
  request.addfinalizer(lambda: sc.stop())

  return sc

pytestmark = pytest.mark.usefixtures("spark_context")

# Test case for the number of captains with a minimum number of matches captained
def test_get_num_captains(spark_context):
    test_input_rdd = spark_context.parallelize(test_input_data, 2)
    assert getNumCaptainsByMinMatches(test_input_rdd, 100) == 5, "Incorrect"


# Test case for the number of matches played by each country
def test_get_num_matches_per_country(spark_context):
    test_input_rdd = spark_context.parallelize(test_input_data, 2)
    assert getNumMatchesPerCountry(test_input_rdd).collect()[0] == ('Australia', 408)
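
On newer pytest versions the same fixture can be written in yield style, which avoids addfinalizer; a sketch:

@pytest.fixture(scope="session")
def spark_context():
    conf = SparkConf().setMaster("local[2]").setAppName("captains-unittest")
    sc = SparkContext(conf=conf)
    yield sc   # everything before the yield is setup
    sc.stop()  # everything after the yield is teardown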

The steps in the pytest code:

  • test_input_data - initializes the test data that is used to test the functions
  • spark_context - is a reusable fixture that initializes the Spark context for all subsequent test cases
  • test_get_num_captains - is the test function for finding how many captains captained more than a certain number of matches (a parametrized variant is sketched after this list)
  • test_get_num_matches_per_country - is the test function for finding the total number of matches played by each country
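
test_get_num_captains checks only one threshold; pytest's parametrization makes it easy to cover several thresholds in a single test. A sketch against the same test data (the thresholds and expected counts below are derived from the eight sample records):

@pytest.mark.parametrize("min_matches, expected", [
    (100, 5),   # five captains led more than 100 matches
    (200, 2),   # Ponting (230) and Fleming (218)
    (300, 0),   # nobody in the sample data
])
def test_num_captains_thresholds(spark_context, min_matches, expected):
    test_input_rdd = spark_context.parallelize(test_input_data, 2)
    assert getNumCaptainsByMinMatches(test_input_rdd, min_matches) == expected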

Steps to execute the pytest code

  • Copy / transfer both the .py files into a directory, e.g. /home/hadoop/lab/programs
  • Ensure that the following environment variables are defined in the system (change the paths as applicable to your system); an alternative using findspark is sketched after this list

    - export SPARK_HOME=/home/hadoop/lab/software/spark-1.6.0-bin-hadoop2.6
    - export PYTHONPATH=$SPARK_HOME/python/:$SPARK_HOME/python/lib/py4j-0.9-src.zip:$PYTHONPATH

  • Run the following command to execute the tests

    - python -m pytest -v captains_test.py
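
Instead of exporting PYTHONPATH by hand, the third-party findspark package (installable with pip install findspark) can set up the path at test time; a minimal conftest.py sketch, assuming SPARK_HOME is set:

# conftest.py - imported by pytest before collecting tests,
# so pyspark becomes importable without exporting PYTHONPATH
import findspark

findspark.init()  # reads SPARK_HOME to locate the pyspark and py4j libraries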

Verifying the script output

The output shows how many test cases passed and how many failed. In verbose mode (-v), it also lists which individual test cases passed and which failed.

A sample output is shown below.

Successful scenario

The results show that both test cases have passed.
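
With both assertions satisfied, the verbose output resembles the following (timings will vary):

captains_test.py::test_get_num_captains PASSED
captains_test.py::test_get_num_matches_per_country PASSED

========================== 2 passed in 15.20 seconds ==========================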

Failure scenario

The following output shows that one of the test cases has failed, which one it is, and the reason for the failure: the actual result was ('Australia', 408), whereas the expected result was ('Australia', 409).
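
If the expected tuple in the second assertion is changed to ('Australia', 409), for example, the verbose output resembles the following (traceback details elided):

captains_test.py::test_get_num_captains PASSED
captains_test.py::test_get_num_matches_per_country FAILED

=============================== FAILURES ===============================
___________________ test_get_num_matches_per_country ___________________
...
E       assert ('Australia', 408) == ('Australia', 409)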