Friday, November 18, 2016

Recursive Mean & Standard Deviation in Scala

Target audience: Intermediate
Estimated reading time: 3'

Have you considered the possibility of accelerating statistical computations? We're introducing the concept of tail recursion as a method to enhance the efficiency of calculating mean and standard deviation.


Note: Scala version 2.11.8

Overview

Computing the mean and standard deviation of a very large data set by summing all the values first may cause the sum to overflow. Tail recursion in Scala is a good alternative: it updates the mean and standard deviation one data point at a time, so no large sum is ever accumulated, whatever the size of the data set.

Direct computation

There are many ways to compute the standard deviation through summation. The first mathematical expression consists of the sum of the squared differences between each data point and the mean.
\[\sigma =\sqrt{\frac{\sum_{i=0}^{n-1}(x_{i}-\mu )^{2}}{n}}\]
The second formula allows the mean and standard deviation to be updated with each new data point (online computation), since only the running sum and the running sum of squares need to be kept.
\[\sigma =\sqrt{\frac{1}{n}\sum_{i=0}^{n-1}x_{i}^{2}-{\mu ^{2}}}\]
This second approach relies on the computation of the sum of squared values, which can overflow.


val x = Array[Double]( /* ... */ )
// Mean of the data set
val mean = x.sum/x.length
// Standard deviation: square root of the mean squared difference to the mean
val stdDev = Math.sqrt((x.map( _ - mean)
    .map(t => t*t).sum)/x.length)

In the stdDev expression, the sum of squared differences can also be computed with a reduceLeft (or a foldLeft) as an alternative to map{ ... }.sum.
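To make the overflow risk of the second formula concrete, here is a minimal sketch that computes the standard deviation both ways. The object name `SumOfSquaresSketch`, the function names, and the `large` sample values are hypothetical and chosen only for illustration; the sample assumes large values with a comparatively small spread.

```scala
object SumOfSquaresSketch {

  // First formula (two passes): mean first, then the squared differences.
  def twoPassStd(x: Array[Double]): Double = {
    val mean = x.sum / x.length
    math.sqrt(x.map(t => (t - mean) * (t - mean)).sum / x.length)
  }

  // Second formula (single pass): accumulate the sum and the sum of squares
  // with foldLeft, then apply sigma = sqrt(sumSq/n - mu^2).
  def sumOfSquaresStd(x: Array[Double]): Double = {
    val (sum, sumSq) = x.foldLeft((0.0, 0.0)) {
      case ((s, sq), t) => (s + t, sq + t * t)
    }
    val mean = sum / x.length
    math.sqrt(sumSq / x.length - mean * mean)
  }

  // Hypothetical data: each value squared exceeds Double.MaxValue,
  // while the squared differences to the mean remain representable.
  val large: Array[Double] =
    Array(1.0e160, 1.0e160 + 1.0e150, 1.0e160 - 1.0e150)
}
```

On small values such as Array(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0) both functions agree on a standard deviation of 2.0. On `large`, the two-pass version still returns a finite value, whereas the sum-of-squares version does not, because every squared term overflows to infinity.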

Recursive computation

There is actually no need to compute the sum and the sum of squared values to compute the mean and standard deviation. The mean and standard deviation for n observations can be computed from the mean and standard deviation of n-1 observations.
The recursive formula for the mean is
\[\mu _{n}=\left ( 1-\frac{1}{n} \right )\mu _{n-1}+\frac{x_{n}}{n}\]
The recursive formula for the standard deviation is
\[\varrho _{n}=\varrho _{n-1}+(x_{n}-\mu _{n})(x_{n}-\mu _{n-1}) \qquad \sigma _{n} = \sqrt{\frac{\varrho _{n}}{n}}\]
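As a short check of the mean recursion, the derivation below assumes the observations are indexed from 1 to n (unlike the 0-based index of the direct formulas) and that the recursion starts from \(\mu_{0}=0\) and \(\varrho_{0}=0\). Splitting the last observation out of the sum gives
\[\mu _{n}=\frac{1}{n}\sum_{i=1}^{n}x_{i}=\frac{(n-1)\mu _{n-1}+x_{n}}{n}=\left ( 1-\frac{1}{n} \right )\mu _{n-1}+\frac{x_{n}}{n}\]
For example, with the two observations \(x_{1}=2\) and \(x_{2}=4\):
\[\mu _{1}=2,\quad \varrho _{1}=0+(2-2)(2-0)=0\]
\[\mu _{2}=\frac{1}{2}\cdot 2+\frac{4}{2}=3,\quad \varrho _{2}=0+(4-3)(4-2)=2,\quad \sigma _{2}=\sqrt{\frac{2}{2}}=1\]
which agrees with the direct formula: the mean is 3, both differences to the mean have magnitude 1, so the standard deviation is 1.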
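Let's implement these two recursive formulas in Scala using tail recursion. The nested meanStd helper, annotated with @tailrec, carries the running mean, the running sum of squared deviations and the number of data points processed so far.

def meanStd(x: Array[Double]): (Double, Double) = {

  // Tail-recursive helper: mu is the running mean, Q the running sum of
  // squared deviations, count the number of data points processed so far.
  @scala.annotation.tailrec
  def meanStd(
      x: Array[Double],
      mu: Double,
      Q: Double,
      count: Int): (Double, Double) =
    // All data points processed: return the mean and the standard deviation
    if (count >= x.length) (mu, Math.sqrt(Q/x.length))
    else {
      val newCount = count + 1
      // Recursive update of the mean
      val newMu = x(count)/newCount + mu * (1.0 - 1.0/newCount)
      // Recursive update of the sum of squared deviations
      val newQ = Q + (x(count) - mu)*(x(count) - newMu)
      meanStd(x, newMu, newQ, newCount)
    }

  meanStd(x, 0.0, 0.0, 0)
}

This implementation updates the mean and the standard deviation simultaneously for each new data point. The recursion exits when all elements have been accounted for (the count >= x.length test).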
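Because each update needs only the previous mean, the previous sum of squared deviations and the count, the same recursion can be applied to data that is not held in an array. The following is a minimal sketch under that assumption, folding over an Iterator; the names `RunningStatsSketch`, `Stats` and `update` are hypothetical and not part of the original implementation.

```scala
object RunningStatsSketch {

  // Running state: number of points seen, mean, and sum of squared deviations.
  final case class Stats(count: Long, mean: Double, q: Double) {

    // Standard deviation of the points seen so far (NaN if none were seen).
    def stdDev: Double =
      if (count == 0L) Double.NaN else math.sqrt(q / count)

    // Apply the two recursive formulas to one new data point.
    def update(x: Double): Stats = {
      val n = count + 1
      val newMean = x / n + mean * (1.0 - 1.0 / n)
      val newQ = q + (x - mean) * (x - newMean)
      Stats(n, newMean, newQ)
    }
  }

  // Consume the iterator one element at a time, keeping only the running state.
  def meanStd(xs: Iterator[Double]): Stats =
    xs.foldLeft(Stats(0L, 0.0, 0.0))((stats, x) => stats.update(x))
}
```

For instance, RunningStatsSketch.meanStd(Iterator(2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0)) yields a mean of 5.0 and a standard deviation of 2.0, up to floating-point rounding. The foldLeft plays the role of the tail recursion here: the accumulator holds the same three values that the recursive helper passes as arguments.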

References