Tuesday, March 4, 2014

Curried and Partial Functions in Scala

Target audience: Intermediate
Estimated reading time: 5'




Introduction

Although most Scala developers have some knowledge of curried and partial functions, many struggle to grasp the use cases to which each of these functional programming techniques applies, and their relative benefits. For a more detailed explanation of currying existing functions, I recommend the excellent post by Daniel Westheide.

Note: For the sake of readability, most non-essential code such as error checking, comments, exception handling, validation of class and method arguments, scoping qualifiers and imports is omitted from the implementation of the algorithms.

Partial functions

Partially defined functions are commonly used to restrict the domain of applicability of function arguments. The restriction can apply either to the type of the argument or to its values. Let's consider dsqrt, which computes the square root of a floating-point value: its argument has to be non-negative. A simple implementation relies on the Option monad.

def dsqrt(x: Double): Option[Double] = 
    if(x<0.0) None else Some(Math.sqrt(x))
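The benefit of returning an Option is that the undefined case composes without explicit checks. The following sketch wraps the same dsqrt in a hypothetical object OptionSqrt (added only so the snippet runs on its own) and chains two calls with flatMap to compute a fourth root.

```scala
object OptionSqrt {
  // Same definition as above: None for a negative argument
  def dsqrt(x: Double): Option[Double] =
    if (x < 0.0) None else Some(Math.sqrt(x))

  def main(args: Array[String]): Unit = {
    // Fourth root of 16.0: sqrt(16.0) = 4.0, then sqrt(4.0) = 2.0
    println(dsqrt(16.0).flatMap(dsqrt))      // Some(2.0)
    // A negative argument short-circuits the whole chain to None
    println(dsqrt(-16.0).flatMap(dsqrt))     // None
    // getOrElse supplies a default value for the undefined case
    println(dsqrt(-1.0).getOrElse(Double.NaN))
  }
}
```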

The same method can be implemented as a partial function, using pattern matching with a guard on the value of the argument, as follows.

val zero = 0.0

def dsqrt: PartialFunction[Double, Double] = {
    case x: Double if x >= zero => Math.sqrt(x)
}

The method dsqrt returns an object of type PartialFunction with an input argument of type Double and an output of type Double. It is defined only for input values x >= 0.0; applying it to any other value throws a scala.MatchError.
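A caller does not have to trigger the MatchError to discover that an argument is out of range. The sketch below (wrapped in a hypothetical object PartialSqrt so it runs standalone) uses three methods of the standard PartialFunction trait: isDefinedAt queries the domain, lift converts the partial function into a total function returning an Option, much like the first version of dsqrt, and applyOrElse supplies a fallback.

```scala
object PartialSqrt {
  val zero = 0.0

  // Same partial function as above: only defined for x >= 0.0
  def dsqrt: PartialFunction[Double, Double] = {
    case x: Double if x >= zero => Math.sqrt(x)
  }

  def main(args: Array[String]): Unit = {
    // Query the domain before applying the function
    println(dsqrt.isDefinedAt(-3.6))   // false
    // lift turns dsqrt into a total function Double => Option[Double]
    println(dsqrt.lift(9.0))           // Some(3.0)
    println(dsqrt.lift(-3.6))          // None
    // applyOrElse calls the fallback for values outside the domain
    println(dsqrt.applyOrElse(-3.6, (_: Double) => Double.NaN))   // NaN
  }
}
```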

Let's evaluate the partial function with different argument types and values. The partial function accepts any argument that can be implicitly converted to Double. The first invocation, dsqrt(3.6), succeeds and returns the square root of 3.6. The second invocation, dsqrt(4L), also succeeds: the Long argument is implicitly converted to Double before the partial function is applied. The third call, dsqrt(-3.6), throws a scala.MatchError, which Try captures as a Failure.

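import scala.util.{Failure, Success, Try}

// Succeeds
Try(dsqrt(3.6)) match {
  case Success(res) => Console.println(res)
  case Failure(e) => Console.println(e.toString)
}

// Succeeds because of the implicit conversion from Long to Double
Try(dsqrt(4L)) match {
  case Success(res) => Console.println(res)
  case Failure(e) => Console.println(e.toString)
}

// Fails with the following message
// "scala.MatchError: -3.6 (of class java.lang.Double)"
Try(dsqrt(-3.6)) match {
  case Success(res) => Console.println(res)
  case Failure(e) => Console.println(e.toString)
}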
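Partial functions also plug directly into the collection API. As a minimal sketch (the object CollectSqrt is a hypothetical wrapper), collect applies dsqrt only to the elements for which it is defined and silently skips the others, so no MatchError is thrown, whereas map would apply it to every element and fail on the first negative value.

```scala
object CollectSqrt {
  // Same partial function as above: only defined for x >= 0.0
  def dsqrt: PartialFunction[Double, Double] = {
    case x: Double if x >= 0.0 => Math.sqrt(x)
  }

  def main(args: Array[String]): Unit = {
    val inputs = List(4.0, -1.0, 9.0, -3.6)
    // Negative values fall outside the domain and are skipped
    println(inputs.collect(dsqrt))   // List(2.0, 3.0)
  }
}
```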

A similar restriction can be applied to the type of the argument. Let's consider the two increment methods add1 and add2 of class Value. Both accept arguments of type AnyVal, so the type of the argument has to be checked at run time. add1 and add2 illustrate two alternative, rather crude, approaches to this type checking.

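class Value(x: Int) {

  // Checks the type of the argument with isInstanceOf, then casts it
  def add1(anyVal: AnyVal): AnyVal = {
    if (anyVal.isInstanceOf[Int]) {
      val value = anyVal.asInstanceOf[Int]
      (x + value).asInstanceOf[AnyVal]
    }
    else if (anyVal.isInstanceOf[Double]) {
      val value = anyVal.asInstanceOf[Double].floor.toInt
      (x + value).asInstanceOf[AnyVal]
    }
    else { }   // unsupported type: returns Unit
  }

  // Matches on the runtime class name. The argument is boxed, so an Int
  // is reported as "java.lang.Integer" and a Double as "java.lang.Double"
  def add2(anyVal: AnyVal): AnyVal = anyVal.getClass.getName match {
    case "java.lang.Integer" => {
      val value = anyVal.asInstanceOf[Int]
      (x + value).asInstanceOf[AnyVal]
    }
    case "java.lang.Double" => {
      val value = anyVal.asInstanceOf[Double].floor.toInt
      (x + value).asInstanceOf[AnyVal]
    }
    case _ => {}   // unsupported type: returns Unit
  }
}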

The two implementations add1 and add2 are cumbersome, to say the least. An alternative implementation that pattern matches on the type of the argument and returns a partial function is far more elegant.

class Value(x: Int) {
  // Defined only for Int and Double arguments
  def add: PartialFunction[Any, Any] = {
    case n: Int => x + n
    case y: Double => x + y.floor.toInt
  }
}

val value = new Value(4)
Console.println(value.add(4.5))   // prints 8

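In the example above, we do not have to write a case for arguments of an improper type: the partial function is simply not defined for them. Applying it to such an argument, for instance value.add("4"), still throws a scala.MatchError, so the caller should check the argument with isDefinedAt first.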
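The domain of add can be tested before the function is applied. In this sketch, which nests the same Value class in a hypothetical object SafeAdd so it runs standalone, isDefinedAt rejects a String argument, and lift returns None for it instead of throwing.

```scala
object SafeAdd {
  class Value(x: Int) {
    // Defined only for Int and Double arguments
    def add: PartialFunction[Any, Any] = {
      case n: Int => x + n
      case y: Double => x + y.floor.toInt
    }
  }

  def main(args: Array[String]): Unit = {
    val value = new Value(4)
    // A String is outside the domain of add
    println(value.add.isDefinedAt("4"))   // false
    // lift returns None instead of throwing a MatchError
    println(value.add.lift("4"))          // None
    // 4 + floor(4.5) = 8
    println(value.add.lift(4.5))          // Some(8)
  }
}
```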

Note: The method Actor.receive, which defines how an actor processes the messages it consumes from its mailbox, relies on a partial function as well.
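Because message handlers of this kind are partial functions, they can be combined with orElse, which tries the first handler and falls back to the second. The sketch below uses no actor library: the object HandlerDemo, the handlers handleNumbers and handleText, and the messages are all hypothetical, and each handler returns a String rather than performing side effects so the result can be printed.

```scala
object HandlerDemo {
  // Hypothetical handler, defined only for Int messages
  val handleNumbers: PartialFunction[Any, String] = {
    case n: Int => s"number $n"
  }

  // Hypothetical handler, defined only for String messages
  val handleText: PartialFunction[Any, String] = {
    case s: String => s"text '$s'"
  }

  // orElse tries handleNumbers first, then handleText
  val handle: PartialFunction[Any, String] = handleNumbers orElse handleText

  def main(args: Array[String]): Unit = {
    println(handle(42))                 // number 42
    println(handle("ping"))             // text 'ping'
    // Unhandled messages can be detected before dispatching
    println(handle.isDefinedAt(3.14))   // false
  }
}
```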

Currying

Currying is the transformation of a function with multiple arguments into a chain of functions, each taking a single argument: if f: (x, y) -> f(x, y), then curry(f): x -> (y -> f(x, y)).
Let's take a simple example: the sum of two floating-point values. The original two-argument function (1) can be converted into a single-argument function that returns an anonymous function taking the second argument as its parameter (2). Scala provides developers with simple syntactic sugar, multiple parameter lists, to define this cascade of function calls (3).

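// Three alternative definitions of sum: define only one of them at a time
def sum(x: Double, y: Double): Double = x+y                  // (1)
def sum(x: Double): Double => Double = (y: Double) => x+y    // (2)
def sum(x: Double)(y: Double): Double = x+y                  // (3)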
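The definition curry(f): x -> (y -> f(x, y)) translates almost literally into Scala. In the sketch below, the generic helper curry and the wrapper object CurryDemo are illustrative names, not library API; the standard library offers the same transformation as the curried method of Function2. The last lines show the practical payoff of form (3): supplying only the first argument list yields a function Double => Double.

```scala
object CurryDemo {
  // Direct translation of curry(f): x -> (y -> f(x, y))
  def curry[A, B, C](f: (A, B) => C): A => B => C =
    (x: A) => (y: B) => f(x, y)

  // Curried method, as in form (3)
  def sum(x: Double)(y: Double): Double = x + y

  def main(args: Array[String]): Unit = {
    val add: (Double, Double) => Double = (x, y) => x + y
    // The hand-written curry and Function2.curried agree
    println(curry(add)(1.0)(2.0))    // 3.0
    println(add.curried(1.0)(2.0))   // 3.0
    // Partial application: fix x = 10.0, leave y open
    val addTen: Double => Double = sum(10.0)
    println(addTen(2.5))             // 12.5
  }
}
```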
 

Many higher-order methods on Scala collections, such as foldLeft, are curried. The following example illustrates the commonly used foldLeft.

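class Collection[T](private val values: Array[T]) {

  // Uncurried version: both arguments in a single parameter list
  def foldLeftUncurried[U](u: U, op: (U, T) => U): U =
    values.foldLeft[U](u)((acc, t) => op(acc, t))

  // Curried version: initial value and operator in separate lists
  def foldLeft[U](u: U)(op: (U, T) => U): U =
    foldLeftUncurried(u, op)
}

val myCollection = new Collection[Int](Array[Int](3, 5, 8))
// The neutral element of a product is 1, not 0
val product = myCollection.foldLeft[Int](1)((prod, x) => prod*x)   // 120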
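A curried signature also enables partial application: supplying only the initial value yields a function that still expects the operator. This sketch uses the standard List.foldLeft, which has the same curried shape as foldLeft in Collection; the object FoldDemo and the value foldFrom1 are illustrative names. The explicit function type on foldFrom1 lets the compiler eta-expand the partially applied method.

```scala
object FoldDemo {
  val xs = List(3, 5, 8)

  // Only the initial value is supplied; the operator is still missing
  val foldFrom1: ((Int, Int) => Int) => Int = xs.foldLeft(1)

  def main(args: Array[String]): Unit = {
    println(foldFrom1(_ * _))   // 1 * 3 * 5 * 8 = 120
    println(foldFrom1(_ + _))   // 1 + 3 + 5 + 8 = 17
  }
}
```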


Are there any benefits to using curried functions instead of functions or methods with multiple arguments? Yes: the type inferencer resolves the type parameters from the first parameter list and can then use them to type-check the second one. Let's consider the signature of the foldLeft method above:
    def foldLeft[U](u: U)(op: (U, T) => U): U

The type inferencer determines the type U from the first argument u and then uses it to type the parameters of the binary operator
op: (U, T) => U
so the function literal passed as op needs no explicit parameter types.
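The difference shows up when the same fold is written with a single parameter list. In the sketch below, foldUncurried and foldCurried are hypothetical helpers wrapped in an object InferenceDemo. Under the assumption of a Scala 2 compiler, passing (acc, x) => acc * x to foldUncurried without parameter types is rejected with a "missing parameter type" error because U is not yet known when the literal is typed, so the types are spelled out; Scala 3 relaxes this restriction. The curried call needs no annotations.

```scala
object InferenceDemo {
  // Single parameter list: U must be inferred together with op
  def foldUncurried[U](xs: List[Int], u: U, op: (U, Int) => U): U =
    xs.foldLeft(u)(op)

  // Curried: U is fixed by the first parameter list before op is typed
  def foldCurried[U](xs: List[Int], u: U)(op: (U, Int) => U): U =
    xs.foldLeft(u)(op)

  def main(args: Array[String]): Unit = {
    val xs = List(3, 5, 8)
    // Explicit parameter types are required here in Scala 2
    val p1 = foldUncurried(xs, 1, (acc: Int, x: Int) => acc * x)
    // acc and x are inferred as Int from the first parameter list
    val p2 = foldCurried(xs, 1)((acc, x) => acc * x)
    println(s"$p1 $p2")   // 120 120
  }
}
```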

References