Target audience: Intermediate
Estimated reading time: 3'
Have you considered the possibility of accelerating statistical computations? We're introducing the concept of tail recursion as a method to enhance the efficiency of calculating mean and standard deviation.
Note: Scala version 2.11.8
Overview
Computing the mean and standard deviation of a very large data set can overflow when the values are accumulated in a single sum. Tail recursion in Scala is a good alternative for computing the mean and standard deviation of a data set of arbitrary size.
Direct computation
There are many ways to compute the standard deviation through summation. The first mathematical expression sums the squared difference between each data point and the mean.
\[\sigma =\sqrt{\frac{\sum_{i=0}^{n-1}(x_{i}-\mu )^{2}}{n}}\]
The second formula makes it possible to update the mean and standard deviation with any new data point (online computation).
\[\sigma =\sqrt{\frac{1}{n}\sum_{i=0}^{n-1}x_{i}^{2}-\mu ^{2}}\]
This second approach relies on the sum of squared values, which can overflow.
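The online computation based on the second formula can be sketched as follows. This is a minimal illustration, not the author's code: the names SumOfSquares, Sums, add and empty are hypothetical, and the clamp to 0.0 is an added assumption to absorb small negative values caused by rounding.

```scala
object SumOfSquares {
  // Hypothetical accumulator: number of points, sum of values, sum of squared values
  final case class Sums(n: Long, sum: Double, sumSq: Double) {
    // Fold one more data point into the running sums
    def add(x: Double): Sums = Sums(n + 1, sum + x, sumSq + x * x)

    // Mean of the points seen so far (NaN when n == 0)
    def mean: Double = sum / n

    // Second formula: sqrt(sumSq / n - mean^2), clamped at 0.0 against rounding
    def stdDev: Double = {
      val mu = mean
      math.sqrt(math.max(0.0, sumSq / n - mu * mu))
    }
  }

  val empty: Sums = Sums(0L, 0.0, 0.0)
}
```

A possible use is Seq(2.0, 4.0, 6.0).foldLeft(SumOfSquares.empty)((acc, v) => acc.add(v)), after which mean and stdDev can be read at any point. Note that sumSq grows with the square of the values, which is the overflow risk mentioned above.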
    val x = Array[Double]( /* ... */ )
    val mean = x.sum / x.length
    // Population standard deviation: square root of the mean squared deviation
    val stdDev = Math.sqrt(
      x.map(_ - mean).map(t => t * t).sum / x.length
    )

A reduceLeft can be used as an alternative to map{ ... }.sum in the computation of stdDev.
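The reduceLeft alternative mentioned above might look like the sketch below. The object and method names are hypothetical, and a foldLeft variant is included as well because it accumulates the squared deviations without building intermediate arrays; that variant is a substitution, not something the original post shows.

```scala
object DirectStats {
  // Single pass over the deviations with foldLeft, starting from 0.0
  def stdDevFold(x: Array[Double]): Double = {
    val mean = x.sum / x.length
    val sumSq = x.foldLeft(0.0)((acc, v) => acc + (v - mean) * (v - mean))
    math.sqrt(sumSq / x.length)
  }

  // reduceLeft needs the squared deviations mapped first,
  // and throws on an empty array because it has no initial value
  def stdDevReduce(x: Array[Double]): Double = {
    val mean = x.sum / x.length
    val sumSq = x.map(v => (v - mean) * (v - mean)).reduceLeft(_ + _)
    math.sqrt(sumSq / x.length)
  }
}
```

Both methods still compute the mean with x.sum first, so they share the overflow concern of the direct approach.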
Recursive computation
There is actually no need to compute the sum and the sum of squared values to compute the mean and standard deviation. The mean and standard deviation for n observations can be computed from the mean and standard deviation of n-1 observations.
The recursive formula for the mean is
\[\mu _{n}=\left ( 1-\frac{1}{n} \right )\mu _{n-1}+\frac{x_{n}}{n}\]
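To see why this formula helps with overflow, the sketch below compares the direct mean with the recursive mean on a deliberately extreme, made-up input. The object name OverflowDemo is hypothetical, and foldLeft is used here only to apply the formula one element at a time; it is not the tail-recursive implementation discussed in this post.

```scala
object OverflowDemo {
  // Deliberately extreme values: their sum exceeds the range of Double
  val x: Array[Double] = Array(Double.MaxValue, Double.MaxValue)

  // Direct computation: the intermediate sum overflows to Infinity
  val directMean: Double = x.sum / x.length

  // Recursive formula mu_n = (1 - 1/n) * mu_(n-1) + x_n / n,
  // applied element by element so that no large sum is ever formed
  val recursiveMean: Double = x
    .foldLeft((0.0, 0)) { case ((mu, n), v) =>
      val k = n + 1
      (mu * (1.0 - 1.0 / k) + v / k, k)
    }
    ._1
}
```

With this input, directMean is infinite while recursiveMean stays finite, since each intermediate value remains within the range of the data.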
The recursive formula for the standard deviation is
\[\varrho _{n}=\varrho _{n-1}+(x_{n}-\mu _{n})(x_{n}-\mu _{n-1}) \ \ \ \ \sigma _{n} = \sqrt{\frac{\varrho _{n}}{n}}\]
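As a quick check of the two recursive formulas, take the three observations 2, 4, 6 (values chosen here purely for illustration), starting from \(\mu _{0}=0\) and \(\varrho _{0}=0\):
\[\mu _{1}=2 \ \ \ \ \varrho _{1}=0+(2-2)(2-0)=0\]
\[\mu _{2}=\frac{1}{2}\cdot 2+\frac{4}{2}=3 \ \ \ \ \varrho _{2}=0+(4-3)(4-2)=2\]
\[\mu _{3}=\frac{2}{3}\cdot 3+\frac{6}{3}=4 \ \ \ \ \varrho _{3}=2+(6-4)(6-3)=8\]
\[\sigma _{3}=\sqrt{\frac{8}{3}}\approx 1.633\]
The direct formula gives the same result: the mean is 4 and the squared deviations sum to 4 + 0 + 4 = 8.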
Let's implement these two recursive formulas in Scala using tail recursion. The @scala.annotation.tailrec annotation asks the compiler to verify that the nested function is tail recursive. In the code, Q plays the role of \(\varrho\) in the formula above.

    def meanStd(x: Array[Double]): (Double, Double) = {
      // mu: running mean, Q: running sum of squared deviations,
      // count: number of observations processed so far
      @scala.annotation.tailrec
      def loop(
          x: Array[Double],
          mu: Double,
          Q: Double,
          count: Int): (Double, Double) =
        // Exit condition: all elements have been processed
        if (count >= x.length) (mu, Math.sqrt(Q / x.length))
        else {
          val newCount = count + 1
          val newMu = x(count) / newCount + mu * (1.0 - 1.0 / newCount)
          val newQ = Q + (x(count) - mu) * (x(count) - newMu)
          loop(x, newMu, newQ, newCount)
        }
      loop(x, 0.0, 0.0, 0)
    }

This implementation updates the mean and the sum of squared deviations Q simultaneously for each new data point. The recursion exits when all elements have been accounted for (count >= x.length), at which point the standard deviation is derived from Q.
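When the data does not fit in an array, the same update can be applied to a stream of values. The sketch below is an assumption-laden variant, not the author's code: StreamingStats, State and update are hypothetical names, the input is assumed to arrive as an Iterator[Double], and foldLeft replaces the explicit tail recursion. The mean update mu + (x - mu) / n is algebraically the same as the recursive formula for the mean given earlier.

```scala
object StreamingStats {
  // Hypothetical state: count, running mean, running sum of squared deviations
  final case class State(count: Long, mu: Double, q: Double) {
    // Apply the recursive formulas for one new observation
    def update(x: Double): State = {
      val n = count + 1
      val newMu = mu + (x - mu) / n
      State(n, newMu, q + (x - mu) * (x - newMu))
    }

    // Population standard deviation; NaN when no value has been seen
    def stdDev: Double = if (count == 0) Double.NaN else math.sqrt(q / count)
  }

  // Consumes the iterator once and returns (mean, standard deviation)
  def meanStd(values: Iterator[Double]): (Double, Double) = {
    val s = values.foldLeft(State(0L, 0.0, 0.0))((st, x) => st.update(x))
    (s.mu, s.stdDev)
  }
}
```

For example, StreamingStats.meanStd(Iterator(2.0, 4.0, 6.0)) walks through the same intermediate values as an array-based version would, without holding the whole data set in memory.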