
Wednesday, December 13, 2023

Distributed Recursive Kalman Filter on Large Datasets

Target audience: Expert
Estimated reading time: 7'

Have you ever pondered the best way to run a Kalman filter efficiently over a very large set of measurements? The predict and update/correct cycle of Kalman filtering can be computationally demanding when dealing with millions of records. Utilizing recursion and Apache Spark might be the solution you're looking for.

Table of contents

       Overview
       Noise
       Prediction
       Tail recursion
       Kalman predictor
       Recursive method

What you will learn: How to implement the Kalman Filter on a vast set of measurements utilizing multiple sampling techniques, recursion, and Apache Spark.

Notes:
  • Environments: Apache Spark 3.4.0, Scala 2.12.11
  • Source code available on GitHub: github.com/patnicolas/kalman
  • To enhance the readability of the algorithm implementations, we have omitted non-essential code elements like error checking, comments, exceptions, validation of class and method arguments, scoping qualifiers, and import statements.
  • A simple Scala implementation of the Kalman filter is described in Discrete Kalman Predictor in Scala


The Challenge

The Internet of Things (IoT) produces an immense volume of data points. Implementing Kalman filtering on these measurements can require significant computational resources. A potential approach to manage this is by sampling the data to approximate its distribution. However, it's important to note that there's no assurance that the chosen sampling technique will maintain the original distribution of the raw data. 


Employing a combination of different types of samplers could help mitigate the effects of a reduced dataset on the precision of the Kalman filter's predictions.

Illustration of sampling and distribution of Kalman predictions using Spark

Kalman filter

First, let's look at the Kalman optimal filter and its implementation on Spark using tail recursion. Renowned in the realms of signal processing and statistical analysis, the Kalman filter serves as a potent tool to estimate the state of a process in the presence of noise arising from the process itself and from the disturbances introduced by measurement instruments [ref 1].

Overview

The Kalman filter is an ideal estimator: it determines parameters from imprecise and indirect measurements, with the goal of minimizing the mean square error of the model's parameters. Being recursive in nature, the algorithm is well suited for real-time signal processing. One notable constraint, however, is that the Kalman filter requires the process to be linear, representable as \(y = a.f(x) + b.g(x) + \ldots\)

Illustration of the Kalman prediction process

Noise

The state of a deterministic linear dynamic system is the smallest vector that summarizes the past of the system in full and allows a theoretical prediction of its future behavior in the absence of noise.
There are two sources of noise:
  • Noise generated by the process, following a normal distribution with zero mean and variance Q: N(0, Q)
  • Noise generated by the measurement devices, which also follows a normal distribution: N(0, R)
Based on an observation or measurement \(z_n\), the true state \(x_n\) is forecasted using the previous state \(x_{n-1}\) and the prior measurement \(z_{n-1}\) through a process that alternates between prediction and update, as illustrated in the diagram below:

Illustration of the sequence of operations in Kalman Filter

Prediction

After initialization, the Kalman Filter forecasts the system's state for the upcoming step and estimates the uncertainty associated with this prediction.

Considering \(A_n\) the state transition model applied to the state \(x_{n-1}\), \(B_n\) the control input model applied to the control vector \(u_n\) if it exists, \(Q_n\) the covariance of the process noise, and \(P_n\) the error covariance matrix, the forecasted state \(\widetilde{x}\) is \[\begin{matrix} \widetilde{x}_{n/n-1}=A_{n}.\widetilde{x}_{n-1/n-1} + B_{n}.u_{n} \ \ (1)\\ P_{n/n-1}=A_{n}.P_{n-1/n-1}.A_{n}^{T}+Q_{n} \ \ (2) \end{matrix}\]

Measurement update & optimal gain

Upon receiving a measurement, the Kalman Filter adjusts or corrects the forecast and uncertainty of the current state. It then proceeds to predict future states, continuing this process. 
Thus, given a measurement \(z_n\), the predicted state \(\widetilde{x}_{n/n-1}\), and the innovation \(S_n\), the Kalman gain \(G_n\), the updated state, and the updated error covariance are calculated as follows: \[\begin{matrix} S_{n}=H.P_{n/n-1}.H^{T}+R_{n} \ \ (3)\\ G_{n}=P_{n/n-1}.H^{T}.S_{n}^{-1} \ \ (4)\\ \widetilde{x}_{n/n}=\widetilde{x}_{n/n-1}+G_{n}(z_{n}-H.\widetilde{x}_{n/n-1}) \ \ (5)\\ g_{n}=I-G_{n}.H \ \ (6)\\ P_{n/n}=g_{n}.P_{n/n-1}.g_{n}^{T}+G_{n}.R_{n}.G_{n}^{T} \ \ (7) \end{matrix}\]

Apache Spark

Apache Spark is a free, open-source framework for cluster computing, specifically designed to process data in real time via distributed computing [ref 2]. Its primary applications include:
  • Analytics: Spark's capability to quickly produce responses allows for interactive data handling, rather than relying solely on predefined queries.
  • Data Integration: Often, the data from various systems is inconsistent and cannot be combined for analysis directly. To obtain consistent data, processes like Extract, Transform, and Load (ETL) are employed. Spark streamlines this ETL process, making it more cost-effective and time-efficient.
  • Streaming: Managing real-time data, such as log files, is challenging. Spark excels in processing these data streams and can identify and block potentially fraudulent activities.
  • Machine Learning: The growing volume of data has made machine learning techniques more viable and accurate. Spark's ability to store data in memory and execute repeated queries swiftly facilitates the use of machine learning algorithms.
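The distributed predictor implemented later in this post expects an implicit SparkSession. Here is a minimal sketch, assuming local execution for experimentation (the application name and master setting are illustrative, not taken from the source code):

import org.apache.spark.sql.SparkSession

// Local session for experimentation; in production, the master and
// resources are typically set by the cluster manager at submission.
implicit val sparkSession: SparkSession = SparkSession.builder()
  .appName("RecursiveKalman")
  .master("local[*]")
  .getOrCreate()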

Implementation

Tail recursion

The Kalman filter operates as a recursive estimator: it computes the estimate of the current state using only the estimated state from the previous time step and the current measurement. Consequently, implementing it with recursion is a natural fit.
A recursive function is tail-recursive if the recursive call is the last operation performed before returning. Tail recursion is significantly faster than non-tail recursion because the compiler can rewrite it as a loop, so no new stack frames are allocated [ref 3].

For measurements z, the recursion processes one measurement at a time:
    execute(z_n, predictions):
        IF z_n is the last measurement:
            return predictions
        ELSE:
            predict
            predicted_z = update
            predictions.add(predicted_z)
            execute(z_n+1, predictions)
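
As a minimal, hypothetical illustration of the pattern in Scala (independent of the Kalman implementation that follows), the @tailrec annotation makes the compiler verify that the recursive call is in tail position:

import scala.annotation.tailrec

// Sum a list of values with a constant-size stack: the recursive call
// is the last operation, so the compiler compiles it down to a loop.
@tailrec
def sum(xs: List[Double], acc: Double = 0.0): Double = xs match {
  case Nil => acc
  case head :: tail => sum(tail, acc + head)
}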


Process & measurement noises

We use the Spark linear algebra types, DenseVector and DenseMatrix, to represent the vectors and matrices associated with the Kalman parameters [ref 4].

case class KalmanNoise(
  qNoise: Double,                // Standard deviation of the process noise
  rNoise: Double,                // Standard deviation of the measurement noise
  length: Int,                   // Number of features or rows associated with the noise
  noiseGen: Double => Double) {  // Distribution function for generating noise

  final def processNoise: DenseMatrix =
    new DenseMatrix(length, length, randMatrix(length, qNoise, noiseGen).flatten)

  final def measureNoise: DenseMatrix =
    new DenseMatrix(length, length, Array.fill(length*length)(noiseGen(rNoise)))
}

The KalmanNoise class deliberately does not assume that the process and measurement noises are Gaussian (white noise), to allow for a more generalized approach. However, in practical applications these noise components often adhere to a normal distribution with a mean of 0 and standard deviations qNoise for the process noise and rNoise for the measurement noise, respectively.
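
For instance, here is a minimal sketch of a Gaussian configuration for a two-feature model (the standard deviations 0.03 and 0.1 are illustrative values, not taken from the use case):

import scala.util.Random

// Hypothetical Gaussian noise: N(0, 0.03) for the process and
// N(0, 0.1) for the measurement device, over a 2-feature state.
val rand = new Random(42L)
implicit val noise: KalmanNoise = KalmanNoise(
  qNoise = 0.03,
  rNoise = 0.1,
  length = 2,
  noiseGen = (stdDev: Double) => stdDev * rand.nextGaussian()
)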

Kalman parameters

To enhance clarity, the source code adopts the same naming convention for variables as those found in the equations for Kalman prediction and update. Adhering to object-oriented design principles, the KalmanParameters class specifies how to calculate the residuals, innovation, and the optimal Kalman gain using these parameters.

case class KalmanParameters(
  A: DenseMatrix,     // State transition dense matrix
  B: DenseMatrix,     // Optional control dense matrix
  H: DenseMatrix,     // Measurement dense matrix
  P: DenseMatrix,     // Error covariance dense matrix
  x: DenseVector) {   // Estimated value dense vector

  private def HTranspose: DenseMatrix = H.transpose
  def ATranspose: DenseMatrix = A.transpose

  /** Compute the residuals y = z - H.x */
  def residuals(z: DenseVector): DenseVector = subtract(z, H.multiply(x))

  /** Compute the innovation S = H.P.H_transpose + measurement noise -- equation (3) */
  def innovation(measureNoise: DenseMatrix): DenseMatrix =
    add(H.multiply(P).multiply(HTranspose), measureNoise)

  /** Compute the Kalman gain G = P.H_transpose.S^-1 -- equation (4) */
  def gain(S: DenseMatrix): DenseMatrix = {
    val invInnovation = inv(S)
    P.multiply(HTranspose).multiply(invInnovation)
  }
}
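
As a hypothetical sketch of how these methods chain together for a single correction step (the helper correctionStep is not part of the source code):

// Residuals and optimal gain for one correction step -- equations (3), (4)
def correctionStep(
  params: KalmanParameters,
  z: DenseVector,
  measureNoise: DenseMatrix): (DenseVector, DenseMatrix) = {
  val y = params.residuals(z)              // y = z - H.x
  val s = params.innovation(measureNoise)  // S = H.P.H^T + R  (3)
  (y, params.gain(s))                      // G = P.H^T.S^-1   (4)
}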


Kalman predictor

The RKalman class, designed for implementing the recursive sequence of predict-update, accepts two parameters:
  • Initial parameters for the Kalman filter, named initialParams.
  • The implicit process and measurement noises, referred to as kalmanNoise.

The predict method executes the two predictive equations, labeled (1) and (2). It allows for an optional control input variable U as its argument. The update method then proceeds to refresh the Kalman parameters, specifically x and the error covariance matrix P, before calculating the optimal Kalman gain.


class RKalman(initialParams: KalmanParameters)(implicit kalmanNoise: KalmanNoise){
  private[this] var kalmanParams: KalmanParameters = initialParams


  def apply(z: Array[DenseVector]): List[DenseVector] = { .. }

  /**   x(t+1) = A.x(t) + B.u(t)
   *    P(t+1) = A.P(t).A^T + Q    */
  def predict(U: Option[DenseVector] = None): DenseVector = {
    // Compute the first term of the state equation: x' = A.x
    val newX = kalmanParams.A.multiply(kalmanParams.x)       // Equation (1)

    // Add the control input if u is provided: x' = A.x + B.u
    val correctedX = U.map(u => add(newX, kalmanParams.B.multiply(u))).getOrElse(newX)

    // Update the error covariance matrix P as P(t+1) = A.P(t).A_transpose + Q
    val newP = add(                                          // Equation (2)
      kalmanParams.A.multiply(kalmanParams.P).multiply(kalmanParams.ATranspose),
      kalmanNoise.processNoise
    )
    // Update the Kalman parameters
    kalmanParams = kalmanParams.copy(x = correctedX, P = newP)
    kalmanParams.x
  }

  /** Implement the update of the state x and error covariance P given the 
   *  measurement z and compute the Kalman gain */
  def update(z: DenseVector): DenseMatrix = {
    val y = kalmanParams.residuals(z)
    val S = kalmanParams.innovation(kalmanNoise.measureNoise) // Equation (3)
    val kalmanGain: DenseMatrix = kalmanParams.gain(S)          // Equation (4)
    
    val nextX = add(kalmanParams.x, kalmanGain.multiply(y))     // Equation (5)
    kalmanParams = kalmanParams.copy(x = nextX)
    
    val nextP = updateErrorCovariance(kalmanGain)               // Equation (7)
    kalmanParams = kalmanParams.copy(P = nextP)
    kalmanGain
  }
}

Recursive method

Let's apply tail recursion to an array of measurements of type DenseVector. The recursion, wrapped in the method recurse, stops once the last measurement has been processed. Otherwise, the Kalman filter's current parameters go through a predict-update cycle for the measurement at the current index, after which the method recursively invokes itself for the measurement at index + 1.

def recurse(z: Array[DenseVector]): List[DenseVector] = {

  @tailrec
  def execute(
    z: Array[DenseVector],
    index: Int,
    predictions: ListBuffer[DenseVector]): List[DenseVector] = {
    if (index >= z.length)   // Criterion to end the recursion
      predictions.toList
    else {
      val nextX = predict()
      val estimatedZ: DenseVector = kalmanParams.H.multiply(nextX)
      predictions.append(estimatedZ)
      update(z(index))

      // Process the next measurement
      execute(z, index + 1, predictions)
    }
  }

  execute(z, 0, ListBuffer[DenseVector]())
}

Sampling-based estimator

We are now ready to put the recursive method to work: the method apply() samples the raw measurement set z using various sampling algorithms and processes each resulting sample in parallel with Apache Spark, where each sample is allocated to a partition. In our approach, we employ a random uniform sampler with varying parameters.

For the current measurement \(z_n\), the next value to be collected is \[z_{next} \leftarrow z_{n+a+rand[0, b]}\] After creation, the samples are processed in parallel using Spark's mapPartitions method. Each worker node creates an instance of RKalman and makes a recursive call through the method recurse.

def apply(
  kalmanParams: KalmanParameters, // Kalman parameters used by the filter/predictor
  z: Array[DenseVector],          // Series of observed measurements as dense vectors
  numSamplingMethods: Int,        // Number of samples to be processed concurrently
  minSamplingInterval: Int,       // Minimum number of measurements skipped between samples
  samplingInterval: Int           // Range of the random sampling
)(implicit sparkSession: SparkSession): Seq[Seq[DenseVector]] = {

  // Generate the various samples from the large set of raw measurements
  val samples: Seq[Seq[DenseVector]] = (0 until numSamplingMethods).map(
    _ => sampling(z, minSamplingInterval, samplingInterval)
  )

  // Distribute the Kalman prediction-correction cycle over the Spark workers
  // by assigning each partition a Kalman process and a sampled measurement set.
  val samplesDS = samples.toDS()
  val predictionsDS = samplesDS.mapPartitions(
    (sampleIterator: Iterator[Seq[DenseVector]]) => {
      val acc = ListBuffer[Seq[DenseVector]]()

      while (sampleIterator.hasNext) {
        implicit val kalmanNoise: KalmanNoise = KalmanNoise(kalmanParams.A.numRows)

        val rKalman = new RKalman(kalmanParams)
        val sample = sampleIterator.next()
        acc.append(rKalman.recurse(sample.toArray))
      }
      acc.iterator
    }
  ).persist()

  predictionsDS.collect()
}
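
As a hypothetical invocation, assuming the measurements zVec and the Kalman parameters from the use case below, with 4 concurrent samplers drawing intervals from [4, 4 + 8]:

// 4 uniform random samplers, each skipping at least 4 measurements
// plus a random offset in [0, 8], processed in parallel.
val predictions: Seq[Seq[DenseVector]] =
  apply(kalmanParams, zVec, numSamplingMethods = 4, minSamplingInterval = 4, samplingInterval = 8)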

The code for our sampling method is listed in the appendix.


Use case

Our implementation is evaluated with measurements of the velocity of a rocket under constant acceleration. The error covariance matrix P is initialized with diagonal values of 0.5.

def velocity(x: Array[Double]): KalmanParameters = {
   val acceleration = 0.0167

   val A = Array[Array[Double]](            //  State transition dense matrix
     Array[Double](1.0, acceleration), Array[Double](0.0, 1.0)
   )
   val H = Array[Array[Double]](          // Measurement dense matrix
     Array[Double](1.0, 0.0), Array[Double](0.0, 0.0)
   )
   val P = Array[Array[Double]](         // Error covariance dense matrix
     Array[Double](0.5, 0.0), Array[Double](0.0, 0.5)
   )
   KalmanParameters(A, None, H, Some(P), x)
}

We simulate the raw measurements for the velocity using the simple formula \[\begin{vmatrix} v(x) \\ 1 \end{vmatrix} \ \ with\ \ v(x)=0.01.x^{2}+\frac{0.002}{x+2}+N(0, 0.2)\].
// Simulated velocity measurements
val z = Array.tabulate(2000)(n =>
  Array[Double](n * n * 0.01 + 0.002 / (n + 2) + normalRandomValue(0.2), 1.0)
)
val zVec = z.map(new DenseVector(_))

// Initial velocity and constant acceleration
val acceleration = 0.0167
val xInitial = Array[Double](0.001, acceleration)
val recursiveKalman = new RKalman(velocity(xInitial))
val predictedStates = recursiveKalman.apply(zVec)

The output of the Kalman predictor \(\widetilde{x}\) is compared with the original velocity measurements z and with the output of a moving average over a 5-second window, \(\widetilde{z}\).
The Kalman predictor uses a uniform random sampling rate over a window [4, 8].
Measurement sample: \(\ldots,\ z_{n},\ z_{n+4+rand[0, 4]},\ \ldots\)
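
For reference, here is a minimal sketch of the non-weighted moving average used as the baseline, assuming one measurement per second so that a 5-second window spans 5 samples:

// Simple non-weighted moving average over the scalar velocity series;
// the first window-1 points are dropped for simplicity.
def movingAverage(velocities: Array[Double], window: Int = 5): Array[Double] =
  velocities.sliding(window).map(_.sum / window).toArray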

The following plot illustrates the behavior of the Kalman predictor and non-weighted moving average for the first 95 seconds.

Raw measurements, 5-sec window moving average, and Kalman predictions plot

The following plot compares the output of the Kalman predictor using a uniform random sampling rate over the windows [4, 8] and [6, 12].

Impact of sampling method/interval on the output of Kalman predictor

References

[1] Introduction to the Kalman Filter, G. Welch, G. Bishop, University of North Carolina
[2] Apache Spark

-------------
Patrick Nicolas has over 25 years of experience in software and data engineering, architecture design, and end-to-end deployment and support, with extensive knowledge in machine learning.
He has been director of data engineering at Aideo Technologies since 2017, and he is the author of "Scala for Machine Learning" (Packt Publishing, ISBN 978-1-78712-238-3).

Appendix

This method updates the error covariance matrix P using the Kalman gain, as defined in equation (7).

def updateErrorCovariance(kalmanGain: DenseMatrix): DenseMatrix = {
  val identity = DenseMatrix.eye(kalmanGain.numRows)
  // (I - G.H).P
  val kHP = subtract(identity, kalmanGain.multiply(kalmanParams.H))
    .multiply(kalmanParams.P)
  // (I - G.H)^T  (the identity matrix is symmetric, so I - (G.H)^T = (I - G.H)^T)
  val kH = subtract(identity, kalmanGain.multiply(kalmanParams.H).transpose)
  // G.R.G^T
  val kR = kalmanGain.multiply(kalmanNoise.measureNoise)
    .multiply(kalmanGain.transpose)

  add(kHP.multiply(kH), kR)
}

The sampling method extracts a subset of the original raw measurements using the formula \[z_{newIndex} \leftarrow z_{currentIndex+a+rand[0, b]}\]
def sampling(
  measurements: Array[DenseVector],  // Raw measurements (z)
  minSamplingInterval: Int,          // Minimum sampling interval
  samplingInterval: Int): Seq[DenseVector] = {

  val rand = new Random(42L)

  // The sampling interval is drawn once per invocation:
  // interval = minSamplingInterval + rand[0, samplingInterval]
  val interval = rand.nextInt(samplingInterval) + minSamplingInterval

  measurements.indices.foldLeft(ListBuffer[DenseVector]())(
    (acc, index) => {
      if (index % interval == 0)
        acc.append(measurements(index))
      acc
    }
  )
}