Target audience: Beginner
Estimated reading time: 4'
Introduction to a simple function approximation algorithm using a dynamically resizable histogram in Scala.
Overview
A typical function approximation consists of finding a model that fits a given data set. Let's consider the following data set {x, y} for which a simple model y = f(x) has to be approximated.
The black dots represent the data set and the red line the model y = f(x)
There are multiple approaches to approximating a model or a function that fits a given data set. The list includes
- Splines
- Least square regression
- Levenberg-Marquardt
- K-nearest neighbors
- Histogram
- Polynomial interpolation
- Logistic regression
- Neural Networks
- Tensor products
- ... and many more
Histogram class
A histogram consists of an array or sequence of bins. A bin is defined by three parameters:
- _x Center of the bin
- _y Average value for the bin
- count Number of data points in the bin
Let's look at an implementation of the Bin class. The constructor takes two values:
- _x mid point of the bin (or bucket)
- _y current frequency for this bin
class Bin(var _x: Double, var _y: Double = 0.0) extends Serializable {
  var count: Int = 0

  // Add a new value y to this bin and update the running average _y
  def += (y: Double): Unit = {
    val newCount = count + 1
    _y = if (count == 0) y else (_y*count + y)/newCount
    count = newCount
  }

  // Merge the content of another bin, next, into this bin
  def + (next: Bin): this.type = {
    val newCount = count + next.count
    if (newCount > 0) {
      _y = (_y*count + next._y*next.count)/newCount
      count = newCount
    }
    this
  }

  // Merge an array of bins into this bin
  def + (next: Array[Bin]): this.type = {
    val newCount = next.aggregate(count)((s, nxt) => s + nxt.count, _ + _)
    if (newCount > 0) {
      _y = next./:(_y*count)((s, nxt) => s + nxt._y*nxt.count)/newCount
      count = newCount
    }
    this
  }
}
The method += (y: Double): Unit adds a new value y to this bin and recomputes the average frequency _y. The method + (next: Bin): this.type merges the content of another bin, next, into this bin, updating both the average and the count. Finally, the method + (next: Array[Bin]): this.type merges an array of bins into this bin.
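Here is a minimal usage sketch of the Bin class; the values are hypothetical and only illustrate the running average and the weighted merge:

val bin = new Bin(0.5)
bin += 3.0
bin += 5.0            // bin._y = 4.0, bin.count = 2

val other = new Bin(0.5)
other += 8.0          // other._y = 8.0, other.count = 1

bin + other           // weighted merge: bin._y = (4.0*2 + 8.0*1)/3 ≈ 5.33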
Next, let's create a class, Histogram, to manage the array of bins. The constructor for the histogram has four parameters:
- maxNumBins maximum number of bins
- min expected minimum value in the data set
- max expected maximum value in the data set
- weights optional smoothing weights
import scala.collection.mutable.ArrayBuffer

final class Histogram(
    maxNumBins: Long,
    var min: Double = -1.0,
    var max: Double = 1.0,
    weights: Array[Float] = Array.empty[Float]) {

  // Initial number of bins (RADIX is an ancillary constant defined outside this snippet)
  val initNumBins: Int = (maxNumBins >> RADIX).toInt
  private var step = (max - min)/initNumBins

  // Bins evenly spaced over [min, max], one bin center every step, starting at min + 0.5*step
  private[this] var values = Range(0, initNumBins-1)
    .scanLeft(min + 0.5*step)((s, n) => s + step)
    ./:(new ArrayBuffer[Bin])((vals, x) => vals += new Bin(x))

  // Update the model with a new data point (x, y)
  def train(x: Double, y: Double): Unit = {
    <<(x)                               // expand the histogram if x falls outside (min, max)
    values(index(x)) += y
    if (values.size >= maxNumBins) >>   // consolidate once the maximum number of bins is reached
  }

  // Predict the value y associated with a new observation x
  final def predict(x: Double): Double = {
    if (x < min || x > max) Double.NaN
    else if (weights.length > 0) smoothPredict(x)
    else values(index(x))._y
  }
  // ... ancillary methods
}
Implementation
The two main methods are
- train, which updates the model (the histogram) with each new data point from the training set. The histogram expands when a new data point falls outside the current boundaries (min, max) of the histogram.
- predict, which predicts the value y for a new observation x. The prediction relies on an interpolation (weighted moving average) whenever the user specifies an array of weights in the histogram constructor. A short usage sketch follows this list.
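Here is a minimal sketch of how the histogram could be trained and queried; the value of RADIX, the number of samples and the target function y = x*x are assumptions for illustration only:

// Assumes RADIX is defined, e.g. RADIX = 2 in a companion object
val hist = new Histogram(maxNumBins = 256L, min = 0.0, max = 1.0)
val rnd = new scala.util.Random(42L)

// Train on 10,000 noisy samples of y = x*x
(0 until 10000).foreach { _ =>
  val x = rnd.nextDouble()
  hist.train(x, x*x + 0.05*rnd.nextGaussian())
}
println(hist.predict(0.5))   // expected to be close to 0.25

The ancillary method index, shown below, maps a value x to the index of the bin that contains it.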
private def index(x: Double): Int = {
  // Map x to the index of its bin, clamped to the valid range of indices
  val idx = ((x - min)/step).floor.toInt
  if (idx < 1) 0
  else if (idx >= values.size) values.size - 1
  else idx
}
In this implementation of the dynamically resizable histogram, the array of bins expands by adding new bins whenever a new data point from the training set has an x value that is either greater than the current maximum or lower than the current minimum. The width of the bins, step, does not change; only the current number of bins does, and it keeps growing until it reaches the maximum number of bins, maxNumBins. For instance, with min = -1.0, max = 1.0 and step = 0.1, a training point at x = 1.32 appends ceil((1.32 - 1.0)/0.1) = 4 new bins centered at 1.05, 1.15, 1.25 and 1.35, and max becomes 1.4. The method Histogram.<< implements the expansion of the histogram for a constant bin width.
private def <<(x: Double): Unit =
  if (x > max) {
    values.appendAll(generate(max, x))
    max = values.last._x + step*0.5
  }
  else if (x < min) {
    // generate returns bins in descending order of _x; reverse to keep values sorted
    values.prependAll(generate(min, x).reverse)
    min = values.head._x - step*0.5
  }

// Generate the new bins required to cover the gap between limit and x
final def generate(limit: Double, x: Double): Array[Bin] =
  if (x > limit) {
    val nBins = ((x - limit)/step).ceil.toInt
    var z = limit - step*0.5
    Array.tabulate(nBins)(_ => { z += step; new Bin(z) })
  }
  else {
    val nBins = ((limit - x)/step).ceil.toInt
    var z = limit + step*0.5
    Array.tabulate(nBins)(_ => { z -= step; new Bin(z) })
  }
Once the maximum number of bins maxNumBins is reached, the histogram is resized by merging adjacent pairs of bins, which halves the number of bins and doubles the bin width step. The consolidation of the histogram bins is implemented by the method Histogram.>>
private def >> : Unit = {
  // Merge each pair of adjacent bins, halving the number of bins...
  values = (0 until values.size-1 by 2)./:(new ArrayBuffer[Bin])(
    (ab, n) => ab += (values(n) + values(n+1))
  )
  // ... and double the bin width accordingly
  step *= 2
}
Testing
The histogram class is tested by loading a data set from a file. The predict method is then used to compute the accuracy of the model through the validate method.
import scala.io.Source
import scala.util.Try

// DELIM is the field delimiter constant defined outside this snippet
def validate(
    hist: Histogram,
    fileName: String,
    eps: Double): Try[Double] = Try {
  val src = Source.fromFile(fileName)
  val fields = src.getLines.map(_.split(DELIM))

  // Count the predictions that fall within eps of the expected value
  val counts = fields./:((0L, 0L))((s, xy) =>
    if (Math.abs(hist.predict(xy(0).trim.toFloat) - xy(1).toFloat) < eps)
      (s._1 + 1L, s._2 + 1L)
    else
      (s._1, s._2 + 1L)
  )
  val accuracy = counts._1.toDouble/counts._2
  src.close
  accuracy
}
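A minimal sketch of how validate might be invoked; the file name and the tolerance eps are assumptions for illustration:

import scala.util.{Success, Failure}

validate(hist, "xy.csv", eps = 0.1) match {
  case Success(accuracy) => println(s"Accuracy: $accuracy")
  case Failure(e) => println(s"Validation failed: ${e.getMessage}")
}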