Testing a Spark Streamlet

A testkit is provided to make it easier to write unit tests for Spark streamlets. The unit tests are meant to facilitate local testing of streamlets.

Basic flow of testkit APIs

Here’s the basic flow that you need to follow when writing tests using the testkit (a condensed sketch of these steps follows the list):

  1. Extend the test class with the SparkScalaTestSupport trait. This trait manages the SparkSession, handles initialization and cleanup, and provides the core APIs of the testkit.

  2. Create a Spark streamlet testkit instance.

  3. Create the Spark streamlet that needs to be tested.

  4. Set up inlet taps on the inlet ports of the streamlet.

  5. Set up outlet taps on the outlet ports.

  6. Push data into inlet ports.

  7. Run the streamlet using the testkit together with the inlet and outlet taps you set up.

  8. Write assertions to ensure that the expected results match the actual ones.
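
Putting these steps together, a test has roughly the following shape. This is only a condensed sketch to show how the steps line up: MyProcessor stands in for the streamlet under test, and Data and Simple are the record types used in the worked example below, which fills in all the details.

class MyProcessorSpec extends SparkScalaTestSupport {             // 1. extend SparkScalaTestSupport

  val testkit = SparkStreamletTestkit(session)                     // 2. create a testkit instance

  "MyProcessor" should {
    "process streaming data" in {
      val processor = new MyProcessor                              // 3. the streamlet under test
      val in  = testkit.inletAsTap[Data](processor.in)             // 4. inlet tap on the inlet port
      val out = testkit.outletAsTap[Simple](processor.out)         // 5. outlet tap on the outlet port

      in.addData((1 to 10).map(i => Data(i, s"name$i")))           // 6. push data into the inlet

      testkit.run(processor, Seq(in), Seq(out), 2.seconds)         // 7. run the streamlet with the taps

      out.asCollection(session) should contain(Simple("name1"))    // 8. assert on the results
    }
  }
}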

Details of the workflow

Let’s consider an example where we would like to write unit tests for a SparkStreamlet that reads data from an inlet, does some processing, and writes the processed data to an outlet. We will follow the steps outlined in the previous section and use ScalaTest as the testing library.

Setting up a sample SparkStreamlet

Here is a list of imports needed for writing the test suite.

import scala.collection.immutable.Seq
import scala.concurrent.duration._

import org.apache.spark.sql.streaming.OutputMode

import cloudflow.streamlets.StreamletShape
import cloudflow.streamlets.avro._
import cloudflow.spark.avro._
import cloudflow.spark.testkit._
import cloudflow.spark.sql.SQLImplicits._

SparkStreamlet is an abstract class. Let’s set up a concrete instance that we would like to test. For more details on how to implement a Spark streamlet, please refer to Building a Spark streamlet.

// create Spark Streamlet
val processor = new SparkStreamlet {
  val in = AvroInlet[Data]("in")
  val out = AvroOutlet[Simple]("out", _.name)
  val shape = StreamletShape(in, out)

  override def createLogic() = new SparkStreamletLogic {
    override def buildStreamingQueries = {
      val dataset = readStream(in)
      val outStream = dataset.select($"name").as[Simple]
      val query = writeStream(outStream, out, OutputMode.Append)
      Seq(query)
    }
  }
}

The unit test

Here’s how we would write a unit test using ScalaTest. The various logical steps of the test are annotated with inline comments explaining the rationale behind each step.

class SparkProcessorSpec extends SparkScalaTestSupport { // 1. Extend SparkScalaTestSupport

  "SparkProcessor" should {

    // 2. Initialize the testkit
    val testkit = SparkStreamletTestkit(session)

    "process streaming data" in {

      // 3. create Spark streamlet
      val processor = new SparkStreamlet {
        val in = AvroInlet[Data]("in")
        val out = AvroOutlet[Simple]("out", _.name)
        val shape = StreamletShape(in, out)

        override def createLogic() = new SparkStreamletLogic {
          override def buildStreamingQueries = {
            val dataset = readStream(in)
            val outStream = dataset.select($"name").as[Simple]
            val query = writeStream(outStream, out, OutputMode.Append)
            Seq(query)
          }
        }
      }

      // 4. setup inlet tap on inlet port
      val in: SparkInletTap[Data] = testkit.inletAsTap[Data](processor.in)

      // 5. setup outlet tap on outlet port
      val out: SparkOutletTap[Simple] = testkit.outletAsTap[Simple](processor.out)

      // 6. build data and send to inlet tap
      val data = (1 to 10).map(i ⇒ Data(i, s"name$i"))
      in.addData(data)

      // 7. Run the streamlet using the testkit and the inlet and outlet taps set up above
      testkit.run(processor, Seq(in), Seq(out), 2.seconds)

      // get data from outlet tap
      val results = out.asCollection(session)

      // 8. Assert that actual matches expectation
      results should contain(Simple("name1"))
    }
  }
}

The SparkScalaTestSupport trait

This trait manages the SparkSession and needs to be mixed in with the main test class. It provides the following functionality:

  1. Manages a SparkSession for all tests, initialized when the test class is initialized.

  2. Cleans up the session using afterAll. If you need custom cleanup logic, override the afterAll method and call super.afterAll() before adding your own logic, as shown in the sketch below.
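
For example, a suite that needs extra teardown can override afterAll as described in point 2. This is a minimal sketch: MyStreamletSpec is a placeholder name and the override body is where your own cleanup goes (the aggregation section below shows a concrete case that stops the StateStore).

class MyStreamletSpec extends SparkScalaTestSupport {

  override def afterAll(): Unit = {
    super.afterAll()                  // let the trait clean up the managed SparkSession first
    // ... suite-specific cleanup goes here ...
  }

  // tests go here
}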

The SparkStreamletTestkit class

  1. Provides the core APIs such as inletAsTap, outletAsTap, and run.

  2. Supports supplying values for configuration parameters, as sketched below.
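
As a sketch of the second point, a testkit instance can be configured with values for a streamlet’s configuration parameters. The withConfigParameterValues builder and the ConfigParameterValue helper used below are assumptions about the testkit API, and myStreamlet.RecordsPerSecond is a hypothetical configuration parameter declared on a hypothetical streamlet; substitute your own parameters and verify the exact signatures against the testkit version you are using.

// A sketch only: `withConfigParameterValues` and `ConfigParameterValue` are assumed to be
// provided by cloudflow.spark.testkit; `myStreamlet.RecordsPerSecond` is a hypothetical
// configuration parameter declared on the streamlet under test.
val testkitWithConfig = SparkStreamletTestkit(session)
  .withConfigParameterValues(ConfigParameterValue(myStreamlet.RecordsPerSecond, "10"))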

Special note on aggregation query

There may be situations where the Spark query that you are testing involves aggregation operators. Writing tests for such queries requires some special consideration. Here’s an example of a Spark streamlet that involves aggregation operators:

package carly.aggregator

import org.apache.spark.sql.Dataset
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._

import cloudflow.streamlets._
import cloudflow.streamlets.avro._
import cloudflow.spark.{ SparkStreamlet, SparkStreamletLogic }
import org.apache.spark.sql.streaming.OutputMode
import cloudflow.spark.sql.SQLImplicits._
import org.apache.log4j.{ Level, Logger }

import carly.data._

class CallStatsAggregator extends SparkStreamlet {

  val rootLogger = Logger.getRootLogger()
  rootLogger.setLevel(Level.ERROR)

  //tag::docs-schemaAware-example[]
  val in    = AvroInlet[CallRecord]("in")
  val out   = AvroOutlet[AggregatedCallStats]("out", _.startTime.toString)
  val shape = StreamletShape(in, out)
  //end::docs-schemaAware-example[]

  val GroupByWindow = DurationConfigParameter("group-by-window", "Window duration for the moving average computation", Some("1 minute"))

  val Watermark = DurationConfigParameter("watermark", "Late events watermark duration: how long to wait for late events", Some("1 minute"))

  override def configParameters = Vector(GroupByWindow, Watermark)
  override def createLogic = new SparkStreamletLogic {
    val watermark     = context.streamletConfig.getDuration(Watermark.key)
    val groupByWindow = context.streamletConfig.getDuration(GroupByWindow.key)

    //tag::docs-aggregationQuery-example[]
    override def buildStreamingQueries = {
      val dataset   = readStream(in)
      val outStream = process(dataset)
      writeStream(outStream, out, OutputMode.Update).toQueryExecution
    }

    private def process(inDataset: Dataset[CallRecord]): Dataset[AggregatedCallStats] = {
      val query =
        inDataset
          .withColumn("ts", $"timestamp".cast(TimestampType))
          .withWatermark("ts", s"${watermark.toMillis()} milliseconds")
          .groupBy(window($"ts", s"${groupByWindow.toMillis()} milliseconds"))
          .agg(avg($"duration").as("avgCallDuration"), sum($"duration").as("totalCallDuration"))
          .withColumn("windowDuration", $"window.end".cast(LongType) - $"window.start".cast(LongType))

      query
        .select($"window.start".cast(LongType).as("startTime"), $"windowDuration", $"avgCallDuration", $"totalCallDuration")
        .as[AggregatedCallStats]
    }
    //end::docs-aggregationQuery-example[]
  }
}

The above query involves aggregation operations, like groupBy over a window and the subsequent averaging of the call durations tracked by the query. Such stateful streaming queries use a StateStore and need to be run with OutputMode.Update (recommended) or OutputMode.Complete when streaming data appears in the incoming Dataset. OutputMode.Update updates the aggregate with the incoming data and is therefore recommended for writing tests involving aggregation queries. OutputMode.Complete, on the other hand, does not drop old aggregation state, since by definition this mode preserves all data in the Result Table.

For queries that do not involve aggregation, only new data needs to be written, and the query needs to be run with OutputMode.Append. This is the default behavior that the testkit sets for the user.

In order to have the query run with OutputMode.Update, the user needs to pass in this argument explicitly when setting up the outlet tap:

// setup outlet tap on outlet port
val out = testkit.outletAsTap[AggregatedCallStats](aggregator.shape.outlet, OutputMode.Update)

For more details on OutputMode, have a look at the relevant Spark documentation.

Also, since stateful queries use the StateStore, it needs to be stopped explicitly when the test suite ends:

import org.apache.spark.sql.execution.streaming.state.StateStore

override def afterAll(): Unit = {
  super.afterAll()
  StateStore.stop() // stop the state store maintenance thread and unload store providers
}