Thursday, July 6, 2017

Runge-Kutta ODE Solver in Scala

Target audience: Advanced
Estimated reading time: 6'


This post describes the implementation of the different Runge-Kutta methods to solve differential equations in Scala.


Note: Implementation relies on Scala 2.11.8

Overview

The objective is to leverage the functional programming components of the Scala programming language to create a generic solver of ordinary differential equations (ODE) using Runge-Kutta family of approximation algorithms.

Most ordinary differential equations cannot be solved analytically. In such cases, a numeric approximation to the solution is often good enough to solve an engineering problem. Oddly enough, most of the commonly used algorithms to compute such an approximation were established a century ago. Let's consider the differential equation \[\frac{\mathrm{d}y }{\mathrm{d} x} = f(x,y)\] The family of explicit Runge-Kutta numerical approximation methods is defined as \[y_{n+1} = y_{n} + \sum_{i=1}^{s}b_{i}k_{i}\\where\,\,k_{i}=h.f(x_{n} + c_{i}h,\, y_{n} + \sum_{j=1}^{i-1} a_{ij}k_{j})\\with\,\,h=x_{n+1}-x_{n}\] Each k(i) is an increment based on the slope of the function evaluated within the interval [x(n), x(n+1)]. The Euler method is defined as \[y_{n+1} = y_{n} + hf(x_{n},y_{n})\] and the 4th order Runge-Kutta as \[y_{n+1} = y_{n} + \frac{h}{6}(k_{1} + 2 k_{2}+2 k_{3}+ k_{4})\,\,\,;h = x_{n+1}-x_{n}\\k_{1}=f(x_{n},y_{n})\\k_{2}=f(x_{n}+\frac{h}{2}, y_{n}+\frac{hk_{1}}{2})\\k_{3}=f(x_{n}+\frac{h}{2},y_{n}+\frac{hk_{2}}{2})\\k_{4}=f(x_{n}+h,y_{n}+hk_{3})\]
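As a concrete illustration, here is a minimal sketch of a single step of the classic 4th order Runge-Kutta method, written independently of the generic solver described below; the function name rk4Step and the sample equation are illustrative only.

def rk4Step(
    f: (Double, Double) => Double,
    x: Double,
    y: Double,
    h: Double): Double = {
  val k1 = f(x, y)
  val k2 = f(x + 0.5*h, y + 0.5*h*k1)
  val k3 = f(x + 0.5*h, y + 0.5*h*k2)
  val k4 = f(x + h, y + h*k3)
  y + h*(k1 + 2.0*k2 + 2.0*k3 + k4)/6.0
}

  // Advance dy/dx = -y from y(0) = 1 by a single step h = 0.1
val y1 = rk4Step((_: Double, y: Double) => -y, 0.0, 1.0, 0.1)  // ~0.904837, close to exp(-0.1)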

The implementation relies on the functional aspects of the Scala language and should be flexible enough to support any future algorithm of the family. The generic Runge-Kutta coefficients a(i,j), b(i) and c(i) are represented as a matrix: \[\begin{vmatrix} c_{2}\,\,a_{21}\,\,0.0\,\,0.0\,\,...\,\,0.0\,\,0.0\\ c_{3}\,\,a_{31}\,\,a_{32}\,\,0.0\,\,...\,\,0.0\,\,0.0 \\ c_{4}\,\,a_{41}\,\,a_{42}\,\,a_{43}\,\,...\,\,0.0\,\,0.0\\ \\ c_{i}\,\,a_{i1}\,\,a_{i2}\,\,a_{i3}\,\,...\,\,\,\,...\,\,a_{i,i-1}\\ *\,\,b_{1}\,\,b_{2}\,\,b_{3}\,\,...\,\,...\,\,b_{i-1}\,\,b_{i} \end{vmatrix}\] In order to illustrate the flexibility of this implementation in Scala, I encapsulate the matrices of coefficients of the Euler, 3rd order Runge-Kutta, 4th order Runge-Kutta and Fehlberg methods using an enumeration and case classes.


Coefficients: enumerators and case classes

Java developers are familiar with enumerators as a data structure used to list values without the need to instantiate an iterable collection.

trait RungeKuttaCoefs {
  type COEFS = Array[Array[Double]]

    // Layout: row i holds c(i+1) in column 0, the a(i+1,j) coefficients
    // in columns 1..i and the weight b(i+1) in column i+2

    // Coefficients for the Euler method
  val EULER = Array(
    Array[Double](0.0, 0.0, 1.0)
  )

    // Coefficients for Runge-Kutta of order 3
  val RK3 = Array(
    Array[Double](0.0,  0.0, 1.0/6, 0.0,   0.0),
    Array[Double](0.5,  0.5, 0.0,   2.0/3, 0.0),
    Array[Double](1.0, -1.0, 2.0,   0.0,   1.0/6)
  )

    // Coefficients for Runge-Kutta of order 4
  val RK4 = Array(
    Array[Double](0.0, 0.0, 1.0/6, 0.0,   0.0,   0.0),
    Array[Double](0.5, 0.5, 0.0,   1.0/3, 0.0,   0.0),
    Array[Double](0.5, 0.0, 0.5,   0.0,   1.0/3, 0.0),
    Array[Double](1.0, 0.0, 0.0,   1.0,   0.0,   1.0/6)
  )

    // Coefficients for Runge-Kutta-Fehlberg of order 5
  val FEHLBERG = Array(
    Array[Double](0.0,     0.0,         25.0/216,     0.0,         0.0,          0.0,         0.0,      0.0),
    Array[Double](0.25,    0.25,        0.0,          0.0,         0.0,          0.0,         0.0,      0.0),
    Array[Double](3.0/8,   3.0/32,      9.0/32,       0.0,         1408.0/2565,  0.0,         0.0,      0.0),
    Array[Double](12.0/13, 1932.0/2197, -7200.0/2197, 7296.0/2197, 0.0,          2197.0/4104, 0.0,      0.0),
    Array[Double](1.0,     439.0/216,   -8.0,         3680.0/513,  -845.0/4104,  0.0,         -1.0/5,   0.0),
    Array[Double](0.5,     -8.0/27,     2.0,          -3544.0/2565, 1859.0/4104, -11.0/40,    0.0,      0.0)
  )

  val rk = List[COEFS](EULER, RK3, RK4, FEHLBERG)
}

object RungeKuttaForms extends Enumeration with RungeKuttaCoefs {
  type RungeKuttaForms = Value
  val Euler, Rk3, Rk4, Fehlberg = Value

  @inline
  final def getRk(value: Value): COEFS = rk(value.id)
}

object RungeKuttaForms extends Enumeration with RungeKuttaCoefs{
  type RungeKuttaForms = Value
  val Euler, Rk3, Rk4, Fehlberg = Value
  
  @inline
  final def getRk(value: Value): COEFS = rk(value.id)
}

The first step is to encapsulate the coefficients of the various versions of the Runge-Kutta formula, Euler, Runge-Kutta order 3, Runge-Kutta order 4 and Runge-Kutta-Fehlberg order 5, in a Scala enumeration.

The enumeration-based design is not particularly elegant. As you browse through the code snippet above, it is clear that wrapping the matrices of coefficients with an enumeration is quite cumbersome. There is a better way: pattern matching. Case objects and case classes can be used instead of the singleton enumeration. Setters or getters can optionally be added, as in the example below.
The validation of method arguments, exceptions and auxiliary methods or variables are omitted for the sake of simplicity.

trait RKMethods { 
  type COEFS = Array[Array[Double]]
  def coefs: COEFS
  def getRk(i: Int, j: Int): Double = coefs(i)(j)
}
 
case object Euler extends RKMethods { 
  override val coefs: COEFS = Array(Array[Double](0.0, 0.0, 1.0))
}
 
case object RK3 extends RKMethods { 
  override val coefs: COEFS = Array(
     Array[Double](0.0,  0.0, 1.0/6, 0.0,   0.0), 
     Array[Double](0.5,  0.5, 0.0,   2.0/3, 0.0),
     Array[Double](1.0, -1.0, 2.0,   0.0,   1.0/6)
  )          
}
....


In this second approach, the values of the enumeration are replaced by the Euler object, RK3 (Runge-Kutta order 3), and so on.
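For instance, pattern matching over these case objects becomes straightforward; the function order below is illustrative only and not part of the original design.

  // Hypothetical use of pattern matching over the case objects
def order(method: RKMethods): Int = method match {
  case Euler => 1
  case RK3 => 3
  case _ => 4   // placeholder; additional case objects would add their own case
}

  // The coefficients are retrieved directly from the object
val rk3Coefs = RK3.coefs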


Integration

The main class RungeKutta implements all the components necessary to solve the ordinary differential equation. This simplified implementation relies on an adjustable step for the integration.

class RungeKutta(
  rungeKuttaForm: RungeKuttaForms.RungeKuttaForms,
  adjustStep: (Double, AdjustParameters) => Double,
  adjustParameters: AdjustParameters) {

  final class StepIntegration(val coefs: Array[Array[Double]]) {}

  def solve(
    xBegin: Double, 
    xEnd: Double, 
    derivative: (Double, Double) => Double): Double   // implementation shown below
}

The class RungeKutta has three arguments:
  • rungeKuttaForm Form or type of the Runge-Kutta method (line 2)
  • adjustStep Function to adjust the integration step dx (line 3)
  • adjustParameters Parameters used to adjust the integration step (line 4)
The computation of the parameters used to adjust the integration step, in the code snippet below, is rather simple. A more elaborate implementation would include several alternative formulas implemented as sealed case classes, as sketched after the next code snippet.

case class AdjustParameters(
     maxDerivativeValue: Double = 0.01,
     minDerivativeValue: Double = 0.00001,
     gamma: Double = 1.0) {

  lazy val dx0 = 2.0*gamma/(maxDerivativeValue 
         + minDerivativeValue)
}
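A possible layout for such alternatives is sketched below; the names StepAdjuster, BoundedLinear and GammaScaled are hypothetical and not part of the original code.

  // Sketch of alternative step-adjustment formulas as a sealed hierarchy
sealed trait StepAdjuster {
  def adjust(diff: Double, params: AdjustParameters): Double
}

  // Linear adjustment bounded by the min and max values of the parameters
case object BoundedLinear extends StepAdjuster {
  override def adjust(diff: Double, params: AdjustParameters): Double = {
    val dx = Math.abs(diff)*params.dx0/params.gamma
    Math.max(params.minDerivativeValue, Math.min(params.maxDerivativeValue, dx))
  }
}

  // Adjustment scaled by gamma and an extra user-defined factor
case class GammaScaled(scale: Double) extends StepAdjuster {
  override def adjust(diff: Double, params: AdjustParameters): Double =
    Math.max(params.minDerivativeValue, scale*params.gamma*Math.abs(diff))
}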

The sum of the previous k values is computed through an inner loop, while the outer loop computes all the k values using the Runge-Kutta coefficient matrix for the selected method. The integration step is implemented as a tail recursion (lines 14 - 22), but an iterative method using foldLeft can also be used (see the sketch after the code snippet). The tail recursion may not be as effective in this case because it is implemented as a closure: the method has to close over ks.

final class StepIntegration(val coefs: Array[Array[Double]]) { 
   
    // Main routine: compute the increment y(n+1) - y(n) for a single step dx
  def compute(
    x: Double, 
    y: Double, 
    dx: Double,
    derivative: (Double, Double) => Double): Double = {
 
    val ks = new Array[Double](coefs.length)
         
      // Tail recursion implemented as a closure over ks
    @scala.annotation.tailrec
    def compute(i: Int, k: Double, sum: Double): Double = {
      ks(i) = k      // increment computed at the previous iteration (unused for i = 0)
      val sumKs = (1 to i)./:(0.0)((s, j) => s + ks(j)*coefs(i)(j))
      val newK = derivative(x + coefs(i)(0)*dx, y + sumKs*dx)
      if( i >= coefs.size - 1)
        sum + newK*coefs(i)(i+2)
      else
        compute(i+1, newK, sum + newK*coefs(i)(i+2))
    }        

    dx*compute(0, 0.0, 0.0)
  }
}
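For comparison, the same accumulation can be expressed with foldLeft; the method computeWithFold below is a sketch only, and the coefficient matrix is passed explicitly to keep the snippet self-contained (row i holds c in column 0, the a coefficients in columns 1 to i, and the weight b in column i+2, as in compute above).

def computeWithFold(
    coefs: Array[Array[Double]],
    x: Double,
    y: Double,
    dx: Double,
    derivative: (Double, Double) => Double): Double = {

  val ks = new Array[Double](coefs.length)
  val sum = coefs.indices.foldLeft(0.0)((acc, i) => {
      // Weighted sum of the increments computed so far
    val sumKs = (0 until i).foldLeft(0.0)((s, j) => s + ks(j)*coefs(i)(j+1))
    val k = derivative(x + coefs(i)(0)*dx, y + sumKs*dx)
    ks(i) = k
    acc + k*coefs(i)(i+2)
  })
  dx*sum
}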

The next method implements the generic solver that iterates through the entire integration interval. The solver itself is implemented as a tail recursion (lines 8 - 20). 

Solver

The accuracy of the solver depends on the value of the increment dx as computed on line 17. We need to weigh the accuracy provided by an infinitesimal increment against its computation cost. Ideally, an adaptive algorithm that computes the value dx according to the value of dy/dx would provide a good compromise between accuracy and cost. The recursion ends when the value x reaches the end of the integration interval (line 14).

def solve(
  xBegin: Double, 
  xEnd: Double, 
  derivative: (Double, Double) => Double): Double = {
  val rungeKutta = new StepIntegration(RungeKuttaForms.getRk(rungeKuttaForm))
   
  @scala.annotation.tailrec
  def solve(
    x: Double, 
    y: Double, 
    dx: Double, 
    sum: Double): Double = {
    val z = rungeKutta.compute(x, y, dx, derivative)
    if( x >= xEnd)
      sum + z
    else {
      val nextDx = adjustStep(z - y, adjustParameters)
      solve(x + nextDx, z, nextDx, sum + z)
    }
  }
  solve(xBegin, 0.0, adjustParameters.dx0, 0.0)
}

The invocation of the solver is straightforward and can be verified against the analytical solution.
The first step is to define the function that adjusts the integration step (lines 1 - 10). This implementation uses the default adjust parameters (line 19) in the initialization of the solver (lines 16 - 19).
 

val adjustingStep = 
(diff: Double, adjustParams: AdjustParameters) => {

  val dx = Math.abs(diff)*adjustParams.dx0/adjustParams.gamma
  if( dx < adjustParams.minDerivativeValue) 
    adjustParams.minDerivativeValue
  else if ( dx > adjustParams.maxDerivativeValue)
    adjustParams.maxDerivativeValue
  else
    dx
}

final val x0 = 0.0
final val xEnd = 2.0

val solver = new RungeKutta(
  RungeKuttaForms.Rk4, 
  adjustingStep, 
  AdjustParameters())

solver.solve(
  x0, 
  xEnd, 
  (x: Double, y: Double) => Math.exp(-x*x))
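As a quick sanity check, the estimate can be compared with the analytical value of the integral of exp(-x*x) over [0, 2], which is sqrt(π)/2 * erf(2), approximately 0.8821. The names estimate and expected below are illustrative, and the achievable accuracy depends on the step-adjustment parameters.

  // Compare the numerical estimate with the known analytical value (~0.8821)
val estimate = solver.solve(x0, xEnd, (x: Double, y: Double) => Math.exp(-x*x))
val expected = 0.8821
println(s"estimate: $estimate expected: ~$expected error: ${Math.abs(estimate - expected)}")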

The family of explicit Runge-Kutta methods provides a good starting point to solve ordinary differential equations. The current implementation could, and possibly should, be extended to support an adaptive dx managed by a control loop using a reinforcement learning algorithm, a Kalman filter or just a simple exponential moving average.
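As an illustration of the last point, here is a minimal sketch of a step-adjustment function smoothed by a simple exponential moving average; the name emaAdjustStep and the smoothing factor alpha are assumptions, not part of the original implementation.

  // Step adjustment smoothed with an exponential moving average (sketch)
def emaAdjustStep(alpha: Double = 0.3): (Double, AdjustParameters) => Double = {
  var ema = 0.0
  (diff: Double, params: AdjustParameters) => {
    val raw = Math.abs(diff)*params.dx0/params.gamma
    ema = if (ema == 0.0) raw else alpha*raw + (1.0 - alpha)*ema
    Math.max(params.minDerivativeValue, Math.min(params.maxDerivativeValue, ema))
  }
}

  // Usage: new RungeKutta(RungeKuttaForms.Rk4, emaAdjustStep(), AdjustParameters())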

References

  • The Numerical Solution of Differential-Algebraic Systems by Runge-Kutta Methods - E. Hairer, C. Lubich, M. Roche - Springer - 1989
  • Programming in Scala - M. Odersky, L. Spoon, B. Venners - Artima Press 2008
  • github.com/patnicolas

Sunday, June 11, 2017

Immutability & Covariance in Scala 2.x

Target audience: Beginner
Estimated reading time: 4'

This post illustrates the concept of immutable and covariant containers/collections in Scala, in the case of the stack data structure. 

Note: This article uses Scala 2.11.6

Overview

There is a relation between immutability and covariance which may not be apparent to a novice Scala programmer. Let's consider the case of a mutable and an immutable implementation of a stack. The mutable stack is a container of elements with methods to push an element onto, and pop the last element from, the stack.

import scala.collection.mutable.ListBuffer

class MutableStack[T] {
  private[this] val _stack = new ListBuffer[T]
  
  final def pop: Option[T] = 
    if(_stack.isEmpty) 
      None 
    else 
      Some(_stack.remove(_stack.size-1))
  
  def push(t: T): Unit = _stack.append(t)
}

The internal container is defined as a ListBuffer instance. The elements are appended to the list buffer (push) and the method pop removes the last element pushed onto the stack.
This implementation has a major drawback: it cannot accept elements of a type other than T because ListBuffer is an invariant collection. Let's consider then an immutable stack.
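The invariance can be verified with a snippet that does not compile (shown here only to illustrate the limitation):

  // Does not compile: MutableStack is invariant in T
  // val anyStack: MutableStack[AnyVal] = new MutableStack[Int]

  // Declaring the class as MutableStack[+T] would not compile either, because
  // T occurs in contravariant position in push(t: T) and inside the mutable ListBuffer[T]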

Immutability and covariance

A covariant immutable stack cannot access its elements unless the elements are contained by the stack itself. This is accomplished by breaking down the stack recursively into the last element pushed onto the stack and the previous state of the stack.

class ImmutableStack[+T](
   val t: T, 
   val stack: Option[ImmutableStack[T]]) {

  def this(t: T) = this(t, Some(new EmptyImmutableStack(t)))
  ...
}

In this recursive approach, the immutable stack is initialized with a single element of type T and an option of the previous immutable stack. The stack can be defined as covariant because its elements are managed by the stack itself.
The next step is to define the initial state of the stack. We could have chosen a singleton empty stack with no element. Instead, we define the first state of the immutable stack as:

 
class EmptyImmutableStack[+T](t: T) extends ImmutableStack[T](t, None) 
 

Next, let's define the pop and push operators for ImmutableStack. The pop method returns the previous state of the immutable stack, that is, the stack prior to the last element being pushed. The push method is contra-variant as it pushes an element of a super type of T. The existing state of this stack is added as the previous state (2nd argument).

final def pop: Option[ImmutableStack[T]] = 
      stack.map(sk => new ImmutableStack[T](sk.t, sk.stack))

def push[U >: T](u: U): ImmutableStack[U] = new ImmutableStack[U](u, Some(this))


Tail recursion

The next step is to traverse the entire stack and return the list of all its elements. This is accomplished through a tail recursion on the state of the stack.

def popAll[U >: T]: List[U] = pop(this,List[U]())
 
@scala.annotation.tailrec
private def pop[U >: T](
    _stck: ImmutableStack[U], 
    xs: List[U]
): List[U] = _stck match { 
  
  case st: EmptyImmutableStack[T] => xs.reverse
  case st: ImmutableStack[T] => {
     val newStack = _stck.stack.getOrElse(
        new EmptyImmutableStack[T](_stck.t)
     )
     pop(newStack, _stck.t :: xs)
  }
}

The recursive method pop (line 4) updates the list xs (line 6) and exits when the stack is empty, of type EmptyImmutableStack (line 9). The list has to be reversed so its elements are indexed from the last pushed to the first (line 9). As long as the stack is not empty (of type ImmutableStack), the method recurses (line 14).
It is time to test drive this immutable stack.

 
val intStack = new ImmutableStack[Int](4)
val newStack = intStack.push(56).push(14).push(77)
 
println(newStack.popAll.mkString(", "))

The values in the stack are: 77, 14, 56, 4.
This example illustrates the concept of an immutable, covariant stack that uses an instance of the stack as its state (the current list of elements it contains).


Sunday, April 2, 2017

Recursive Minimum Spanning Tree in Scala

Target audience: Intermediate
Estimated reading time: 6'

Determining the best way to link nodes is frequently encountered in network design, transport ventures, and electrical circuitry. This piece explores and showcases an efficient computation for the minimum spanning tree (MST) through the use of Prim's algorithm, which is built on tail recursion.
This article assumes a very minimal understanding of undirected graphs.

Note: Implementation relies on Scala 2.11.8

Overview

Each connection in a graph is usually assigned a weight (cost, length, time...). The purpose is to compute the layout that connects all the nodes while minimizing the total weight. This problem is known as the minimum spanning tree or MST of the nodes connected through an undirected graph [ref 1].

Several algorithms have been developed over the last 70 years to extract the MST from a graph. This post focuses on the implementation of Prim's algorithm in Scala.

There are many excellent tutorials on graph algorithms and more specifically on Prim's algorithm. I recommend Lecture 7: Minimum Spanning Trees and Prim’s Algorithm [ref 2].

Let PQ be a priority queue and G(V, E) a graph with n vertices V and edges E with weights w(u,v). A vertex v is defined by 
  • An identifier
  • A load factor, load(v)
  • A parent tree(v)
  • The adjacent vertices adj(v)
Prim's algorithm can be easily expressed as a simple iterative process. It consists of maintaining a priority queue of all the vertices in the graph and updating their load to select the next node in the spanning tree. Each node is popped (and removed) from the priority queue before being inserted in the tree.
PQ <- V(G)
foreach u in PQ
   load(u) <- INFINITY
 
while PQ nonEmpty
   u <- pop(PQ)
   foreach v in adj(u)
      if v in PQ && w(u,v) < load(v)
      then
         tree(v) <- u
         load(v) <- w(u,v)
The Scala implementation relies on a tail recursion to transfer vertices from the priority queue to the spanning tree.

Graph definition

The first step is to define a graph structure with edges and vertices [ref 3]. The graph takes two arguments:
  • numVertices number of vertices
  • start index of the root of the minimum spanning tree
The vertex class has three attributes
  • id identifier (an arbitrary integer)
  • load dynamic load (or key) on the vertex
  • tree reference to the minimum spanning tree
import scala.collection.mutable.HashMap

final class Graph(numVertices: Int, start: Int = 0) {
 
  class Vertex(val id: Int, 
     var load: Int = Int.MaxValue, 
     var tree: Int = -1) 

  val vertices = List.tabulate(numVertices)(new Vertex(_))
  vertices(start).load = 0
  val edges = new HashMap[Vertex, HashMap[Vertex, Int]]

  def += (from: Int, to: Int, weight: Int): Unit = {
    val fromV = vertices(from)
    val toV = vertices(to)
    connect(fromV, toV, weight)
    connect(toV, fromV, weight)
  }

  def connect(from: Vertex, to: Vertex, weight: Int): Unit = {
    if( !edges.contains(from))
      edges.put(from, new HashMap[Vertex, Int])    
    edges.get(from).get.put(to, weight)
  }   
  // ...
}

The vertices are initialized with a unique identifier id, a default load of Int.MaxValue and a default tree index of -1. The Vertex class resides within the scope of the outer class Graph to avoid naming conflicts. The vertices are managed through a list while the edges are defined as a hash map keyed by vertex, with a map of adjacent vertices and weights as value. The operator += adds a new edge between two existing vertices with a specified weight.
In most cases, the identifier is a character string or a more complex data structure. As described in the pseudo-code, the load of the root of the spanning tree is set to 0.

The load is defined as an integer for performance's sake. For the processing of very large graphs, it is recommended to quantize floating-point weights into integers, then convert the resulting minimum spanning tree back to the original format (a possible scheme is sketched below).
The edges are defined as a hash table with the source vertex as key and, as value, a hash table of destination vertices and edge weights. 
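The quantization scheme could look like the sketch below; the helper names quantize and dequantize and the scale factor are assumptions, not part of the original code.

  // Hypothetical quantization of floating-point weights into integers and back
final val WeightScale = 1e6

def quantize(weight: Double): Int = Math.round(weight*WeightScale).toInt
def dequantize(weight: Int): Double = weight/WeightScale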


The graph is undirected; therefore the connections initialized in the method += are bi-directional.

Priority queue

The priority queue is used to re-order the vertices and select the next vertex to be added to the spanning tree.

Note: There are many different implementations of priority queues in Scala and Java. Keep in mind that Prim's algorithm requires the queue to be re-ordered after a vertex load is updated (see the pseudo-code). The PriorityQueue classes in the Scala and Java libraries do not allow elements to be removed or explicitly re-ordered. An alternative is to use a binary tree or red-black tree, for which elements can be removed and the tree re-balanced.

The implementation of the priority queue has an impact on the time complexity of the algorithm. The following implementation of the priority queue is provided only to illustrate Prim's algorithm. 

import scala.collection.mutable.{ListBuffer, PriorityQueue}

class PQueue(vertices: List[Vertex]) {
   var queue = vertices./:(new PriorityQueue[Vertex])((pq, v) => pq += v)
    
   def += (vertex: Vertex): Unit = queue += vertex
   def pop: Vertex = queue.dequeue
     // Re-build the queue so that updated loads are taken into account
   def sort: Unit = 
     queue = queue.toList./:(new PriorityQueue[Vertex])((pq, v) => pq += v)
   def push(vertex: Vertex): Unit = queue.enqueue(vertex)
   def nonEmpty: Boolean = queue.nonEmpty
}
  
type MST = ListBuffer[Int]
implicit def orderingByLoad[T <: Vertex]: Ordering[T] = Ordering.by( - _.load)  


The Scala PriorityQueue class requires an implicit ordering of the vertices using their load. This is accomplished by defining an implicit conversion of a type T with upper bound Vertex to Ordering[T].

Notes
  • The type T has to be a sub-class of Vertex. A direct conversion from Vertex type to Ordering[Vertex] is not allowed in Scala.
  • We use the mutable PriorityQueue from the Scala library as it provides more flexibility than the Scala TreeSet.

Prim's algorithm

This implementation is a direct translation of the pseudo-code presented earlier. It relies on Scala's efficient tail recursion (line 5).

def prim: List[Int] = {
  val queue = new PQueue(vertices)
   
  @scala.annotation.tailrec
  def prim(parents: MST): Unit = {
    if( queue.nonEmpty ) {
      val head = queue.pop
      val candidates = edges.get(head).get
          .filter{ 
            case(vt,w) => vt.tree == -1 && w <= vt.load
          }
 
      if( candidates.nonEmpty ) {
        candidates.foreach {case (vt, w) => vt.load = w }
        queue.sort
      }
      parents.append(head.id)
      head.tree = 1
      prim(parents)
    }
  }
  val parents = new MST
  prim(parents)
  parents.toList
}

As long as the priority queue is not empty (line 6), the next element of the priority queue is retrieved (line 7) and the most appropriate candidates for the next vertex are selected (lines 8 - 11). The load of each candidate is updated (line 14) and the priority queue is re-sorted (line 15).
Although a tree set is a more efficient data structure for managing the set of vertices waiting to be weighted, its immutability does not allow the existing priority queue to be re-sorted in place.
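A small usage sketch follows, assuming prim is exposed as a method of Graph as the snippets above suggest; the graph and its weights are arbitrary.

  // Build a small undirected graph and extract the order of the MST vertices
val graph = new Graph(numVertices = 5)
graph += (0, 1, 9)
graph += (0, 2, 4)
graph += (1, 2, 3)
graph += (1, 3, 7)
graph += (2, 4, 12)
graph += (3, 4, 2)

val mstVertices: List[Int] = graph.prim
println(mstVertices.mkString(" -> "))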

Time complexity

As mentioned earlier, the time complexity depends on the implementation of the priority queue. If E is the number of edges, and V the number of vertices:
  • Minimum spanning tree with a linear queue: O(V^2)
  • Minimum spanning tree with a binary heap: O((E + V) log V)
  • Minimum spanning tree with a Fibonacci heap: O(E + V log V)
Note: See Summary of time complexity of algorithms for details.


References

[1] Introduction to Algorithms, Chapter 24 Minimum Spanning Trees - T. Cormen, C. Leiserson, R. Rivest - MIT Press 1989
[2] Lecture 7: Minimum Spanning Trees and Prim’s Algorithm - Dekai Wu, Department of Computer Science and Engineering - The Hong Kong University of Science & Technology
[3] Graph Theory, Chapter 4 Optimization Involving Tree - V.K. Balakrishnan - Schaum's Outlines Series, McGraw Hill, 1997