Friday, December 1, 2023

Explainable ML models with SHAP

Target audience: Beginner
Estimated reading time: 5'

Have you ever faced the need to rationalize the prediction made by one of your models, or to identify which features are crucial? If so, SHAP values and plots are your go-to resources, offering the fundamental structure for an explanation.


Table of contents
       Use case
       Dataset
       Models
       Metrics
       Dependency plot
       Decision plot
       Force plot

What you will learn: How to use SHAP values and plots to identify the most significant features of multi-class classification models.


Introduction

SHAP (SHapley Additive exPlanations), which is based on the concepts of game theory, is employed to clarify the predictions of machine learning models [ref 1]. This approach evaluates the contribution of each feature to a model's prediction, aiding in pinpointing the key features and understanding their specific effects on the model's results.

The complete description of the theory behind SHAP [ref 2] is beyond the scope of this article but can be summarized as follows:
For a set M of players and a subset S of M that excludes player i, the SHAP value of player (feature) i is
\[\varphi _{i}= \sum _{S \subseteq M \setminus \left \{ i \right \}}\frac{|S|!\,(|M|-|S|-1)!}{|M|!}\left ( f(S \cup \left \{ i \right \}) - f(S) \right )\]
where f is the prediction model and the sum runs over all subsets \[S \subseteq M \setminus \left \{ i \right \}\] of players that exclude player i.

The prediction made by a model, denoted as f, can be expressed as a constant base value plus the sum of its SHAP values, as shown in the equation:
\[f(x)=\text{base value}+\sum_{i}\varphi _{i}(x)\]
To begin a global interpretation using SHAP, one should first look at the average absolute SHAP value of every feature across the entire dataset. This quantity measures the average magnitude of each feature's contribution, whether positive or negative, to the predicted air quality index.
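As a quick illustration of this additivity property, here is a minimal, self-contained sketch on a synthetic binary problem. It is not taken from this article's repository; the model, data and sample sizes are arbitrary choices for demonstration.

import numpy as np
import shap
from sklearn.linear_model import LogisticRegression

# Synthetic binary classification problem, for illustration only
rng = np.random.default_rng(42)
X = rng.normal(size=(200, 4))
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)
model = LogisticRegression().fit(X, y)

# Kernel explainer on the predicted probability of class 1,
# using the first 100 rows as background data
explainer = shap.KernelExplainer(lambda z: model.predict_proba(z)[:, 1], X[:100])
shap_values = explainer.shap_values(X[100:110])

# Additivity: base value + sum of SHAP values ~ model output
reconstructed = explainer.expected_value + shap_values.sum(axis=1)
print(np.allclose(reconstructed, model.predict_proba(X[100:110])[:, 1], atol=1e-2))

# Global importance: mean absolute SHAP value per feature
print(np.abs(shap_values).mean(axis=0))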

Use Case

SHAP values serve various purposes, including:
  • Debugging models to spot biases or anomalies in the data.
  • Assessing feature importance to pinpoint and eliminate features with minimal impact.
  • Providing detailed explanations for individual predictions.
  • Summarizing models using SHAP value summary plots.
  • Detecting biases to determine if specific features have an undue influence on certain groups.
  • Facilitating regulatory approval by elucidating the model's decision-making process.
In this article, our aim is to calculate SHAP values and analyze the significance of each feature in three classification models. These models are used to forecast Air Quality in 138 cities across the Philippines.

Dataset

We use the Air Quality Index (AQI) dataset covering weather data for 138 Philippine cities, collected from Open Weather Map and available in the Kaggle data repository [ref 3].

The 8 features are the pollutant concentrations that contribute to air pollution: carbon monoxide (CO), nitrogen monoxide (NO), nitrogen dioxide (NO2), ozone (O3), sulphur dioxide (SO2), ammonia (NH3), and particulates (PM2.5 and PM10).
The 5 labels/classes are indexed as Good (1), Fair (2), Moderate (3), Poor (4), and Very Poor (5).
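As a hedged sketch of how the dataset can be loaded and inspected with pandas: the file path is the one used in the snippet further below, while the pollutant column names follow the Open Weather Map convention and are assumptions here (only main.aqi, o3, pm2_5 and pm10 appear explicitly in this article).

import pandas as pd

# Path taken from the snippet further below; adjust to your local copy of the Kaggle CSV
df = pd.read_csv('../../data/Philippine_Air_Quality.csv')

# Pollutant columns assumed from the Open Weather Map AQI schema
feature_cols = ['co', 'no', 'no2', 'o3', 'so2', 'nh3', 'pm2_5', 'pm10']
label_col = 'main.aqi'

print(df[feature_cols].describe())     # distribution of the 8 pollutant features
print(df[label_col].value_counts())    # class balance across the 5 AQI labels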

SHAP values and plots

First we implement the class SHAPEval to compute the SHAP values and generate Summary, Dependency, Force and Decision plots, given a predictive model, model_predictor [ref 4].
from typing import AnyStr, List
import numpy as np
import shap


class SHAPEval(object):
  def __init__(self, model_predictor, plot_type: SHAPPlotType):
     self.model_predictor = model_predictor
     self.plot_type = plot_type

  def __call__(self, validation_data: np.ndarray, column_names: List[AnyStr]) -> None:
        # 1- Compute SHAP values with a model-agnostic Kernel explainer
     shap_descriptor = shap.KernelExplainer(self.model_predictor, validation_data)
     shap_values = shap_descriptor.shap_values(validation_data)

        # 2- Apply the selected plot to the validation data and extracted SHAP values
     match self.plot_type:
        case SHAPPlotType.SUMMARY_PLOT:
            shap.summary_plot(shap_values, validation_data, feature_names=column_names)

        case SHAPPlotType.PARTIAL_DEPENDENCY_PLOT:
            shap.dependence_plot("o3", shap_values, validation_data, feature_names=column_names)

        case SHAPPlotType.FORCE_PLOT:
            data_point_rank = 8
            shap.force_plot(
                     shap_descriptor.expected_value,
                     shap_values[data_point_rank, :],
                     validation_data[data_point_rank, :],
                     feature_names=column_names,
                     matplotlib=True)

        case SHAPPlotType.DECISION_PLOT:
            shap.decision_plot(
                     shap_descriptor.expected_value,
                     shap_values,
                     feature_names=column_names,
                     link='logit')
        case _:
            raise Exception(f'Plot type {self.plot_type} is not supported')


The dunder method __call__ accepts a test dataset, validation_data, and a list of feature names, column_names, for the following purposes:
  1. To calculate SHAP values using a Kernel Explainer.
  2. To create various SHAP visualizations.
Different types of explainers exist for various models, such as the TreeExplainer for random forests, the SamplingExplainer for models with independent features, or the DeepExplainer for differentiable models [ref 5].

For our purposes, we have chosen the Kernel Explainer. Its approach of employing weighted linear regression to determine the significance of each feature is particularly well-suited for models like logistic regression, support vector machines, and neural networks.
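As a quick, hedged illustration of this choice of explainer (the synthetic data and hyper-parameters are arbitrary, and the TreeExplainer line is not part of this article's pipeline):

import shap
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression

# Synthetic data, illustration only
X, y = make_classification(n_samples=300, n_features=8, random_state=5713)

# Tree-based models: TreeExplainer computes exact SHAP values efficiently
rf = RandomForestClassifier(n_estimators=50, random_state=5713).fit(X, y)
tree_shap_values = shap.TreeExplainer(rf).shap_values(X[:50])

# Model-agnostic approach used in this article: KernelExplainer on the
# prediction function, with a (sub)sample of the data as background
lr = LogisticRegression(max_iter=1000).fit(X, y)
kernel_shap_values = shap.KernelExplainer(lr.predict, X[:50]).shap_values(X[:20])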

Models

Following this, we apply SHAPEval to each of the three models. The ModelEval class, designed for evaluating models, has a constructor with four parameters:
  • filename: This refers to the location of the CSV file that holds the Air Quality Index data.
  • dropped_features: A list of features deemed irrelevant, which will be omitted from the training dataset.
  • label: The column that serves as the target for the classification model.
  • val_train_split: This denotes the proportion of samples allocated for validation compared to training.
@dataclass
class TestMetric:
  accuracy: float
  f1: float
  mean_squared_error: float



class ModelEval(object):
  random_state = 5713
   
  def __init__(self,
                 filename: AnyStr,
                 dropped_features: List[AnyStr],
                 label: AnyStr,
                 val_train_split: float):

     def set_label(x: float) -> int:
        return int(x) - 1

     df = pd.read_csv(filename)
        # Drop non features and label columns
     dropped_features.append(label)
     X = df.drop(dropped_features, axis=1)
        
        # Apply standard normalization
     X_scaled = StandardScaler().fit(X).transform(X)
        # Select column containing label
     y = df[label].apply(set_label)
        
         # Train - validation split
     self.feature_names = X.columns.values.tolist()
     self.X_train, self.X_val, self.y_train, self.y_val = \
            train_test_split(X_scaled, y, test_size=val_train_split, random_state=ModelEval.random_state)



   def __call__(self, model_type: ModelType, plot_type: SHAPPlotType) -> TestMetric:
          # Initialize the classification model
      match model_type:
        case ModelType.LOGISTIC_REGRESSION:
            model = LogisticRegression(
                    solver='lbfgs', 
                    max_iter=1000, 
                    penalty='l2', 
                    multi_class='multinomial')

        case ModelType.SVM:
            model = SVC(
                   kernel="rbf", 
                   decision_function_shape='ovo', 
                   random_state=ModelEval.random_state)

        case ModelType.MLP:
            model = MLPClassifier(
                    hidden_layer_sizes=(32, 16),
                    max_iter=500,
                    alpha=0.0001,
                    solver='adam',
                    random_state=ModelEval.random_state)
        case _:
            raise Exception(f'Model name {model_type} is not supported')
             
             # Train the model
      model.fit(self.X_train, self.y_train)
             # Compute SHAP values and selected plots
      shap_eval = SHAPEval(model.predict, plot_type)
      shap_eval(self.X_val,  self.feature_names)
             
             # prediction and quality metrics
      y_predicted = model.predict(self.X_val)
      return TestMetric(
            accuracy_score(self.y_val, y_predicted),
            f1_score(self.y_val, y_predicted, average='weighted'),
            mean_squared_error(self.y_val, y_predicted)
        )


The following code snippet instantiates the ModelEval class to generate a decision plot (SHAPPlotType.DECISION_PLOT) for the logistic regression model (ModelType.LOGISTIC_REGRESSION).

test_filename = '../../data/Philippine_Air_Quality.csv'
test_drop_features = ['datetime', 'coord.lon', 'coord.lat', 'extraction_date_time', 'city_name']
test_label = 'main.aqi'
test_size = 0.01

try:
   model_eval = ModelEval(test_filename, test_drop_features, test_label, test_size)
   test_metrics = model_eval(ModelType.LOGISTIC_REGRESSION, SHAPPlotType.DECISION_PLOT)
 
except SHAPException as e:
    print(str(e))
except Exception as e:
    print(str(e))


Evaluation

The three models being evaluated are:
  • Logistic regression with the L-BFGS solver and L2 regularization
  • Support vector machine with a radial basis function (RBF) kernel and a one-vs-one ('ovo') decision function shape
  • Multi-layer perceptron with two hidden layers of sizes 32 and 16, trained with the Adam solver

Metrics

The quality metrics output for the three models are:

Model                      Accuracy    F1-Score    MSE
Logistic Regression        0.928       0.924       0.119
Support Vector Machine     0.974       0.953       0.025
Multi-Layer Perceptron     0.992       0.989       0.002


Comparative summary plots

API: shap.summary_plot(shap_values, data, feature_names)

Initially, we calculate and present a summary report detailing the SHAP values for all three models: logistic regression, support vector machine, and multi-layer perceptron. This plot illustrates the positive and negative correlations between the predictors and the target variable. 
The 'dotty' appearance of the plot arises from the inclusion of every data point of the validation set. By examining the distribution and positioning of the dots across the various features, we can assess which features exert the most influence. Some features may demonstrate a uniform effect (indicated by closely grouped dots), whereas others may show more diverse impacts (evidenced by dots that are more widely scattered).

SHAP summary plot for Logistic Regression with 156 samples

SHAP summary plot for Support Vector Machine with 96 samples

SHAP summary plot for Multi-layer Perceptron with 780 samples

The data points in the plot are arranged along the X-axis according to their SHAP values, ranging from -0.6 to 2.2. The thickness of the stack at each SHAP value indicates how many data points share that value, i.e., the density or concentration of the SHAP value. Additionally, the vertical 'feature value' color bar encodes the actual (raw) value of the feature for each data point.

In these plots, the features like o3, pm2_5, and others are ordered from top to bottom according to their average absolute SHAP value.

The consistency of SHAP values across the three models—logistic regression, support vector machine, and multi-layer perceptron—emphasizes the significance of the o3 and pm2_5 components in influencing the predictions. Notably, the Multi-layer perceptron model displays one or two predominant SHAP values for each feature, aligning with its high f1 score as a classifier.

Dependency plot

API:  shap.dependence_plot('o3', shap_values, data, feature_names)

The dependency plot illustrates the impact that one or two variables exert on the predicted result, revealing the nature of the relationship—whether it's linear, monotonic, or more intricate—between the target and the variables. This type of plot is especially useful for understanding models based on ensemble methods and deep learning.

We will proceed to create a SHAP dependence plot for the neural network model, utilizing a dataset of 780 samples.

SHAP dependency between o3 and pm10 components plot for MLP with 780 samples

The x-axis represents the numerical values of the feature o3. The y-axis shows the SHAP values for both o3 and pm10 features. The higher the value, the greater the impact on the prediction.
The high dispersion along the y-axis indicates that there is some dependency between the targeted feature o3 and other features, primarily pm10.
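To pin the interaction colouring to pm10 rather than letting SHAP select it automatically, the dependence plot accepts an interaction_index argument. The sketch below reuses the shap_values, validation_data and column_names variables from the SHAPEval snippet above.

# Colour the o3 dependence plot by the pm10 feature instead of the
# automatically selected interaction feature (interaction_index defaults to "auto")
shap.dependence_plot(
    "o3",
    shap_values,
    validation_data,
    feature_names=column_names,
    interaction_index="pm10")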

Decision plot

API: shap.decision_plot(expected_value, shap_values, feature_names, link='logit')

SHAP decision plots reveal the process by which complex models make their predictions, essentially illustrating the decision-making mechanism of these models. In these plots, features are ranked in order of their importance, which is calculated based on the observations being plotted.

Each observation's predicted outcome is depicted by a line of a specific color. These lines intersect the x-axis at the top of the plot, at points that correspond to the predicted values for the observations. The predicted value is what determines the color of the line, typically represented on a spectrum.

The plot effectively demonstrates how the contribution of each feature adds up to the final prediction made by the model.


SHAP Decision plot on 156 samples for logistic regression


The dataset's average prediction, also known as the base value, is set at 0.64. The features, such as o3 and others, are organized in a descending order based on their significance. Each line in the plot represents either a test or validation sample and shows the cumulative effect of each feature. A movement towards the right of the base value (0.64) signifies that the feature positively influences the prediction. Conversely, a shift towards the left indicates that the feature negatively affects the prediction.

In the plot, 156 validation samples are illustrated, culminating in four distinct final probability values: 0.43, 0.73, 0.88, and 0.98.

Force Plot

API: shap.force_plot(expected_value, shap_values[index,:], data[index,:], feature_names, matplotlib=True)

For each observation, you can create a sophisticated visualization known as the force plot. In these plots, features are arranged from left to right, with those making a positive impact positioned on the left and those with a negative impact on the right. For the 8th observation, the key features influencing the model's prediction are highlighted in red and blue. Red indicates the features that increased the model's score, while blue denotes the features that decreased the score.

SHAP observation force plot for 8th sample with logistic regression


Each feature's contribution is represented by an arrow, colored to reflect its impact. The size and orientation of these arrows demonstrate both the strength and the nature (positive indicated by red, negative by blue) of each feature's influence on the prediction.

As highlighted in the summary plot, the o3 component emerges as a primary feature, exerting a negative effect on the prediction with a score of -0.746. Conversely, the pm2_5 feature makes a positive contribution, impacting the prediction with a score of 0.246.

Limitations

Despite its usefulness, SHAP comes with certain constraints, including:
  • It demands substantial computational resources, especially for intricate multi-label or multi-class models trained on extensive datasets (a common mitigation is sketched after this list).
  • The computation relies on the assumption of feature independence, particularly in the case of Kernel or Linear SHAP.
  • While SHAP reveals the extent to which a feature influences a prediction, it does not explain how these features collectively contribute to the target variable.
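One way to contain the cost of the Kernel explainer, shown here as a hedged sketch, is to summarize the background data with shap.kmeans (or sub-sample it with shap.sample) before building the explainer. The model, X_train and X_val names refer back to the ModelEval snippet, and the value of k is an arbitrary choice.

import shap

# Summarize the background data into k weighted centroids to keep the
# Kernel explainer tractable on large datasets (k = 20 is an arbitrary choice)
background = shap.kmeans(X_train, 20)        # alternative: shap.sample(X_train, 100)
explainer = shap.KernelExplainer(model.predict, background)

# Explain only a modest number of validation points
shap_values = explainer.shap_values(X_val[:100])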
Thank you for reading this article. For more information ...

References



----------------------------------
Patrick Nicolas has over 25 years of experience in software and data engineering, architecture design and end-to-end deployment and support with extensive knowledge in machine learning. 
He has been director of data engineering at Aideo Technologies since 2017 and he is the author of "Scala for Machine Learning", Packt Publishing ISBN 978-1-78712-238-3


Wednesday, November 22, 2023

Tracking Storms with Kafka/Spark Streaming

Target audience: Intermediate
Estimated reading time: 4'

“There is peace even in the storm.” Vincent Van Gogh

Certain challenges in our everyday life demand immediate attention, with severe storms and tornadoes being prime examples. This article demonstrates the integration of Kafka event queues with Spark's distributed structured streaming, crafting a powerful and responsive system for predicting potentially life-threatening weather events.


Table of contents
Streaming frameworks
      Apache Kafka
      Apache Spark
      Spark structured streaming
Use case
      Overview
      Data streaming pipeline
Implementation
      Weather data sources
      Weather tracking pipeline
References
Important notes:
  • Software versions utilized in the source code include Scala 2.12.15, Spark 3.4.0, and Kafka 3.4.0. 
  • This article focuses on the data streaming pipeline. Explaining the storm prediction model falls outside the scope of this article. 
  • You can find the source code on the Github-Streaming library [ref 1]
  • Note that error messages and the validation of method arguments are not included in the provided code examples.

Streaming frameworks

The goal is to develop a service that warns local authorities about severe storms or tornadoes. Given the urgency of analyzing data and suggesting actions, our alert system is designed as a streaming data flow.

The main elements include:
  • A Kafka message queue, where each monitoring device type is allocated a specific topic, enabling the routing of different alerts to the relevant agency.
  • The Spark streaming framework to handle weather data processing and to disseminate forecasts of serious weather disruptions.
  • A predictive model specifically for severe storms and tornadoes.

Apache Kafka 

Apache Kafka is an event streaming platform enabling:
  • Publishing (writing) and subscribing (reading) to event streams, as well as continuously importing/exporting data from various systems.
  • Durable and reliable storage of event streams for any desired duration.
  • Real-time or retrospective processing of event streams.
Kafka is accessible as an open-source library [ref 2] or through commercial cloud services such as Confluent [ref 3].

You can stop and restart the Kafka service using shell scripts in the following way:

# Stop any running Kafka broker and ZooKeeper instances
zookeeper-server-stop
kafka-server-stop
sleep 2

# Restart ZooKeeper, then the Kafka broker
zookeeper-server-start $KAFKA_ROOT/kafka/config/zookeeper.properties &
sleep 1
ps -ef | grep zookeeper
kafka-server-start $KAFKA_ROOT/kafka/config/server.properties &
sleep 1
ps -ef | grep kafka

For testing purposes, we deploy Apache Kafka on a local host listening on the default port 9092. Here are some useful commands.
To list the existing topics:
    kafka-topics --bootstrap-server localhost:9092 --list
To create a new topic (e.g., doppler):
    kafka-topics 
        --bootstrap-server localhost:9092 
        --topic doppler 
        --create 
        --replication-factor 1 
        --partitions 2
To list the messages or events currently queued for a given topic (e.g., weather):
    kafka-console-consumer
        --topic weather 
        --from-beginning 
        --bootstrap-server localhost:9092

Here is an example of libraries required to build a Kafka consumer and producer application.

<scala.version>2.12.15</scala.version>
<kafka.version>3.4.0</kafka.version>

<dependency>
  <groupId>org.apache.kafka</groupId>
  <artifactId>kafka-streams</artifactId>
  <version>${kafka.version}</version>
</dependency>

<dependency>
  <groupId>org.apache.kafka</groupId>
  <artifactId>kafka-streams-scala_2.12</artifactId>
  <version>${kafka.version}</version>
</dependency>


Apache Spark

Apache Spark is a free, open-source framework for cluster computing, specifically designed to process data in real time via distributed computing. Its primary applications include:
  • Analytics: Spark's capability to quickly produce responses allows for interactive data handling, rather than relying solely on predefined queries.
  • Data Integration: Often, the data from various systems is inconsistent and cannot be combined for analysis directly. To obtain consistent data, processes like Extract, Transform, and Load (ETL) are employed. Spark streamlines this ETL process, making it more cost-effective and time-efficient.
  • Streaming: Managing real-time data, such as log files, is challenging. Spark excels in processing these data streams and can identify and block potentially fraudulent activities.
  • Machine Learning: The growing volume of data has made machine learning techniques more viable and accurate. Spark's ability to store data in memory and execute repeated queries swiftly facilitates the use of machine learning algorithms.
Here is an example of libraries required to build a Spark application:

<spark.version>3.3.2</spark.version>
<scala.version>2.12.15</scala.version>

<dependency>
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-core_2.12</artifactId>
   <version>${spark.version}</version>
</dependency>

<dependency>
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-sql_2.12</artifactId>
   <version>${spark.version}</version>
</dependency>

<dependency>
   <groupId>org.apache.spark</groupId>
   <artifactId>spark-streaming-kafka-0-10_2.12</artifactId>
   <version>${spark.version}</version>
</dependency>

Spark structured streaming

Spark Structured Streaming, built atop Spark SQL, is a streaming data engine. It processes data in increments, continually updating outcomes as additional streaming data is received [ref 4]. A Spark Streaming application consists of three primary parts: the source (input), the processing engine (business logic), and the sink (output).

For comprehensive details on setting up, applying, and deploying the Spark Structured Streaming library, visit Boost real-time processing with Spark Structured Streaming.

In our specific scenario, data input is managed via Kafka consumers, and data output is handled through Kafka producers [ref 5].

Use case 

Overview

This use case involves gathering data from weather stations and Doppler radars, then merging these data sources based on location and time stamps. After consolidation, the unified data is sent to a model that forecasts potentially hazardous storms or tornadoes. The resulting predictions are then relayed back to the relevant authorities (such as emergency personnel, newsrooms, law enforcement, etc.) through Kafka.

The two sources of data are:
  • Weather stations, which collect temperature, pressure and humidity.
  • Doppler radar, which provides advanced data on wind intensity and direction.

In practice, meteorologists receive overlapping data features from weather stations and Doppler radars, which they then reconcile and normalize. However, in our streamlined use case, we choose to omit these overlapping features.

Illustration of collection of data from weather stations and Doppler radars


The process for gathering data from weather data tracking devices is not covered in this context, as it is not pertinent to the streaming pipeline.

Data streaming pipeline

The monitoring streaming pipeline is structured into three phases:
  1. Kafka queue.
  2. Spark's distributed structured streams.
  3. A variety of storm and tornado prediction models, developed using the PyTorch library and accessible via REST API.
As noted in the introductory section, the PyTorch-based model is outside this post's focus and is mentioned here only for context.

Storm tracking streaming pipeline

Data gathered from weather stations and Doppler radars is fed into the Spark engine, where both streams are combined and harmonized based on location and timestamp. This unified dataset is then employed for training the model. During the inference phase, predictions are streamed back to Kafka.

Adding an intriguing aspect to the implementation, both weather and Doppler data are ingested from Kafka as batch queries, and the storm forecasts are subsequently streamed back to Kafka.


Implementation

Weather data sources

Weather tracking device
Initially, we need to establish the fundamental characteristics of weather tracking data points: location (longitude and latitude) and time stamp. These are crucial for correlating and synchronizing data from different devices. The details on how these data points are synchronized will be explained in the section Streaming processing: Data aggregation.

trait TrackingData[T <: TrackingData[T]]  {
self =>
   val id: String                 // Identifier for the weather tracking device
   val longitude: Float       // Longitude for the weather tracking device
   val latitude: Float          // Latitude for the weather tracking device
   val timeStamp: String   // Time stamp data is collected
}


Weather station
We collect temperature, pressure and humidity parameters from weather stations given their location (longitude, latitude) and time interval (timeStamp). Therefore a weather station record, WeatherData inherits tracking attributes from TrackingData.

case class WeatherData (
   override val id: String,              // Identifier for the weather station
   override val longitude: Float,    // Longitude for the weather station
   override val latitude: Float,       // Latitude for the weather station
   override val timeStamp: String = System.currentTimeMillis().toString,      
      // Time stamp data is collected
  
   temperature: Float,         // Temperature (Fahrenheit) collected at timeStamp
   pressure: Float,              // Pressure (millibars) collected at timeStamp
   humidity: Float) extends TrackingData[WeatherData]  {  // Humidity (%) collected at timeStamp


   // Random generator for testing
  def rand(rand: Random, alpha: Float): WeatherData = this.copy(
      timeStamp = (timeStamp.toLong + 10000L + rand.nextInt(2000)).toString,
      temperature = temperature*(1 + alpha*rand.nextFloat()),
      pressure = pressure*(1 + alpha*rand.nextFloat()),
      humidity = humidity*(1 + alpha*rand.nextFloat())
   )

   // Encoded weather station data produced to Kafka queue
  override def toString: String =
    s"$id;$longitude;$latitude;$timeStamp;$temperature;$pressure;$humidity"
}

We randomly generate the values of some attributes for testing purposes using the formula \[x(1 + \alpha r_{[0,1]})\] with alpha ~ 0.1 and r drawn from a uniform distribution between 0 and 1.
The string representation, with ';' as delimiter, is used to serialize the tracking data for Kafka.

Doppler radar
We retrieve wind-related measurements (such as wind shear, average wind speed, gust speed, and direction) from the Doppler radar for specific locations and time intervals. Similar to the weather station, the 'DopplerData' class is a sub-class of the 'TrackingData' trait.

case class DopplerData(
  override val id: String,                // Identifier for the Doppler radar
  override val longitude: Float,     // Longitude for the Doppler radar
  override val latitude: Float,        // Latitude for the doppler radar
  override val timeStamp: String, // Time stamp data is collected
  windShear: Boolean,                 // Is it a wind shear?
  windSpeed: Float,                     // Average wind speed
  gustSpeed: Float,                      // Maximum wind speed
  windDirection: Int) extends TrackingData[DopplerData] {

   // Random generator
  def rand(rand: Random, alpha: Float): DopplerData = this.copy(
     timeStamp = (timeStamp.toLong + 10000L + rand.nextInt(2000)).toString,
     windShear = rand.nextBoolean(),
     windSpeed = windSpeed * (1 + alpha * rand.nextFloat()),
     gustSpeed = gustSpeed * (1 + alpha * rand.nextFloat()),
     windDirection = {
        val newDirection = windDirection * (1 + alpha * rand.nextFloat())
        if(newDirection > 360.0) newDirection.toInt%360 else newDirection.toInt
     }
  )

  //  Encoded Doppler radar data produced to Kafka
  override def toString: String = s"$id;$longitude;$latitude;$timeStamp;$windShear;$windSpeed;" +
    s"$gustSpeed;$windDirection"
}


Weather tracking streaming pipeline

Streaming tasks
To enhance understanding, we simplify and segment the processing engine (business logic) into two parts: the aggregation of input data and the actual model prediction.

The streaming pipeline comprises four key tasks:
  • Source: Retrieves data from weather stations and Doppler radars via Kafka queues.
  • Aggregation: Synchronizes and consolidates the data.
  • Storm Prediction: Activates the model to provide potential storm advisories, communicated through a REST API.
  • Sink: Distributes storm advisories to the relevant Kafka topics.
class WeatherTracking(
  inputTopics: Seq[String],
  outputTopics: Seq[String],
  model: Dataset[ModelInputData] => Seq[WeatherAlert])(implicit sparkSession: SparkSession) {

  def execute(): Unit =
    for {
       (weatherDS, dopplerDS) <- source                                    // Step 1
       consolidatedDataset <- synchronizeSources(weatherDS, dopplerDS)     // Step 2
       predictions <- predict(consolidatedDataset)                         // Step 3
       encodedPredictions <- sink(predictions)                             // Step 4
    } yield { encodedPredictions }
  

In every computational task within the streaming pipeline, exceptions are caught and converted into options, which are then chained using a Scala for-comprehension.

Streaming source
The data is consumed as a data frame, df, with the key and value defined as strings (1). 

Illustration of encoding/decoding weather tracking data

The key for the source is designated as W_${weather station id} for weather data and D_${Doppler radar id} for Doppler data. The value consists of the encoded tracking data (toString) (2), which is then split by data source (3). Subsequently, the weather tracking data is decoded into instances of weather station and Doppler radar data (4).

def source: Option[(Dataset[WeatherData], Dataset[DopplerData])] = try {
  import sparkSession.implicits._

  val df = sparkSession.read.format("kafka")
       .option("kafka.bootstrap.servers", "localhost:9092")
       .option("subscribe", inputTopics.mkString(","))
       .option("max.poll.interval.ms", 2800)
       .option("fetch.max.bytes", 32768)
       .load()                                                            // (1)

  val ds = df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
                  .as[(String, String)]                                   // (2)

  // Convert weather data stream into objects
  val weatherDataDS = ds.filter(_._1.head == 'W')                         // (3)
      .map{ case (_, value) => WeatherDataEncoder.decode(value) }         // (4)

  // Convert Doppler radar data stream into objects
  val dopplerDataDS = ds.filter(_._1.head == 'D')                         // (3)
      .map{ case (_, value) => DopplerDataEncoder.decode(value) }         // (4)

  Some((weatherDataDS, dopplerDataDS))
} 
catch { case e: Exception => None }


The implementation of the source for batch queries may include a range of Kafka consumer configuration parameters, such as max.partition.fetch.bytes, connections.max.idle.ms, fetch.max.bytes, or max.poll.interval.ms [ref 6].
Meanwhile, the decode method transforms a semicolon-delimited string into an instance of either WeatherData or DopplerData.
The readStream method would be used instead if the data were consumed from Kafka as a stream.
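As a minimal sketch of such a decoder, assuming the field order produced by WeatherData.toString (the actual WeatherDataEncoder in the repository may differ):

object WeatherDataEncoder {
  // Decode "id;longitude;latitude;timeStamp;temperature;pressure;humidity"
  // back into a WeatherData instance (no validation, sketch only)
  def decode(encoded: String): WeatherData = {
    val fields = encoded.split(";")
    WeatherData(
      id = fields(0),
      longitude = fields(1).toFloat,
      latitude = fields(2).toFloat,
      timeStamp = fields(3),
      temperature = fields(4).toFloat,
      pressure = fields(5).toFloat,
      humidity = fields(6).toFloat
    )
  }
}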

Streaming processing: Data aggregation
The goal is to align data from weather stations and Doppler radars that fall within approximately ±20 seconds of the same timestamp. Timestamps are therefore bucketed into 20-second intervals.

def synchronizeSources(
  weatherDS: Dataset[WeatherData],
  dopplerDS: Dataset[DopplerData]): Option[Dataset[ModelInputData]] = try {
   
  val timedBucketedWeatherDS = weatherDS.map(
      wData => wData.copy(timeStamp = bucketKey(wData.timeStamp))
  )
  val timedBucketedDopplerDS = dopplerDS.map(
      dData => dData.copy(timeStamp = bucketKey(dData.timeStamp))
  )
      // Perform a fast, pre-sorted join
  val output = sortingJoin[WeatherData, DopplerData](
      timedBucketedWeatherDS,
      tDSKey = "timeStamp",
      timedBucketedDopplerDS,
      uDSKey = "timeStamp"
  ).map {
    case (weatherData, dopplerData) =>
        ModelInputData(
          weatherData.timeStamp,
          weatherData.temperature,
          weatherData.pressure,
          weatherData.humidity,
          dopplerData.windShear,
          dopplerData.windSpeed,
          dopplerData.gustSpeed,
          dopplerData.windDirection
        )
   }
   Some(output)
} catch { case e: Exception => None }


The sortingJoin is a parameterized method that pre-sorts the data in each partition prior to joining the two datasets. The implementation is available at Github - Spark fast join implementation.
The data class for the input to the model is described in the appendix.
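The bucketKey helper used in synchronizeSources is not listed here; a minimal sketch, assuming epoch timestamps in milliseconds and the 20-second buckets mentioned above, could look like:

// Hypothetical sketch of the time-bucketing helper referenced in synchronizeSources.
// Assumes the time stamp is epoch time in milliseconds and 20-second buckets.
def bucketKey(timeStamp: String): String = {
  val bucketDurationMs = 20000L
  ((timeStamp.toLong / bucketDurationMs) * bucketDurationMs).toString
}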

Streaming sink
Finally, the sink encodes the weather alerts generated by the storm predictor. Unlike the streaming source, the data is produced to the Kafka topic as a data stream.

def sink(stormPredictions: Dataset[StormPrediction]): Option[Dataset[String]] = try {
  import sparkSession.implicits._

   // Encode the dataset of weather alerts to be produced to the Kafka topic
  val encStormPredictions = stormPredictions.map(_.toString)

  encStormPredictions
        .selectExpr("CAST(value AS STRING)")
        .writeStream
        .format("kafka")
        .option("kafka.bootstrap.servers", "localhost:9092")
        .option("topic", outputTopics.head)   // single sink topic; several topics would require a 'topic' column
        .option("checkpointLocation", "/tmp/weather-alerts-checkpoint")  // required by the Kafka sink; path is illustrative
        .start()

  Some(encStormPredictions)
} catch { case e: Exception => None }

The implementation of the sink can also set various Kafka producer configuration parameters, such as buffer.memory, retries, batch.size, max.request.size, or linger.ms [ref 7]. The write method should be used instead if the data is to be produced to Kafka in batch.

Thank you for reading this article. For more information ...

References


Appendix

Storm prediction input
case class ModelInputData(
   timeStamp: String,       // Time stamp for the new consolidated data, input to model
   temperature: Float,      // Temperature in Fahrenheit
   pressure: Float,           // Barometric pressure in millibars
   humidity: Float,           // Humidity in percentage
   windShear: Boolean,  // Boolean flag to specify if this is a wind shear
   windSpeed: Float,      // Average speed for the wind (miles/hour)
   gustSpeed: Float,       // Maximum speed for the wind (miles/hour)
   windDirection: Float   // Direction of the wind [0, 360] degrees
)

Storm prediction output
case class StormPrediction(
   id: String,                        // Identifier of the alert/message 
   intensity: Int,                  // Intensity of the storm or tornado [1, 5]
   probability: Float,           // Probability that a storm of the given intensity develops
   timeStamp: String,         // Time stamp of the alert or prediction
   modelInputData: ModelInputData,// Weather and Doppler radar data used to generate/predict the alert
   cellArea: CellArea          // Area covered by the alert/prediction
)


---------------------------
Patrick Nicolas has over 25 years of experience in software and data engineering, architecture design and end-to-end deployment and support with extensive knowledge in machine learning. 
He has been director of data engineering at Aideo Technologies since 2017 and he is the author of "Scala for Machine Learning" Packt Publishing ISBN 978-1-78712-238-3