Saturday, September 3, 2022

Pattern Matching: Python vs. Scala

Target audience: Beginner
Estimated reading time: 3'   

Ever found yourself frustrated by those pesky chains of if and elif statements? Python's latest versions (3.10 and above) offer a remedy.
This article describes structural pattern matching in Python [ref 1] and how it compares to its definition and implementation in Scala.



Notes
  • Environments:  Python 3.10, Scala 2.13.2
  • To enhance the readability of the algorithm implementations, we have omitted non-essential code elements like error checking, comments, exceptions, validation of class and method arguments, scoping qualifiers, and import statements.

Overview

As a data engineer using Scala for data pre-processing and Python for deep learning frameworks, I have always been interested in comparing the features of these two languages.

Python already supports a limited form of pattern matching through sequence unpacking assignments. There were two approaches to matching a value or pattern in older versions of Python (similar to the switch/case idiom available in most programming languages):
  • Write a sequence of if/elif/else condition-action pairs.
  • Create a dictionary with the condition as key and the action as value.
Neither of these options is flexible or easy to maintain. It was only a matter of time before Python joined quite a few other programming languages in adopting structural pattern matching, in version 3.10 [ref 2].

This new feature is very similar to the pattern matching found in the Scala programming language. Wouldn't it be interesting to compare the semantics and extensibility of pattern matching in Python and Scala?

Pattern matching in Scala

Typed pattern matching has been part of the Scala programming language for the last 10 years [ref 3]. Its purpose is to check a value against one or several values or patterns. This feature is more powerful than the switch statement found in most programming languages, as it can deconstruct a value or class instance into its constituent parts.

In the following example, the type of a status instance (derived from the trait Status) is matched against all possible subtypes (Failure, Success and Unknown).

sealed trait Status

case class Failure(error: String) extends Status
case object Success extends Status
case object Unknown extends Status
  

def processStatus(status: Status): String = status match {
    case Failure(errorMsg) => errorMsg
    case Success => "Succeeded"
    case Unknown => "Undefined status"
}

Note that the set of types derived from Status is sealed (restricted). Therefore the function processStatus does not need to handle undefined types: the compiler verifies that the match is exhaustive.
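A brief usage sketch of processStatus (the inputs are chosen for illustration):

println(processStatus(Failure("Connection timed out")))   // Connection timed out
println(processStatus(Success))                           // Succeeded
println(processStatus(Unknown))                           // Undefined status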


Python value-type pattern

This is the simplest construct for pattern matching: a value, along with its type, is checked against a given set of values.

from typing import Any
from enum import Enum

class EnumValue(Enum):
    SUCCESS = 1
    FAILURE = -1
    UNKNOWN = 0


def variable_matching(value: Any):
    match value:
        case 2.0:
            print(f'Input {value} is matched as a float')
        case "3.5":
            print(f'Input {value} is matched as a string')
        case EnumValue.SUCCESS:
            print('Success')
        case _:
            print(f'Failed to match {value}')


if __name__ == '__main__':
    variable_matching(3.5)              # Failed to match 3.5
    variable_matching("3.5")            # 3.5 is matched as a string
    variable_matching(EnumValue.FAILURE)   # Failed to match EnumValue.FAILURE


In the previous code snippet, the argument of the function variable_matching is checked against a set of values AND their types. The float input 3.5 fails to match the case "3.5" because the type is incorrect.

The following truth table illustrates the basic matching algorithm:

Matched type | Matched value | Outcome
-------------+---------------+----------
No           | No            | Failed
Yes          | No            | Failed
No           | Yes           | Failed
Yes          | Yes           | Succeeded
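For comparison, here is a minimal Scala sketch of the same value-and-type matching; the function name and the values are illustrative, not taken from a library:

def variableMatching(value: Any): Unit = value match {
    case 2.0 => println(s"Input $value matched as a double")
    case "3.5" => println(s"Input $value matched as a string")
    case _ => println(s"Failed to match $value")
}

variableMatching(3.5)       // Failed to match 3.5
variableMatching("3.5")     // Input 3.5 matched as a string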


What about more complex types?

Python mappings pattern

The previous section dealt with matching single values and types. What about more complex structures, such as dictionaries?
The following code snippet illustrates the mechanism for matching a pattern of key-value pairs.

from typing import Dict, Optional
#  Dict keys: 'name', 'status', 'role', 'bonus'


def mappings_matching(json_dict: Dict) -> Optional[Dict]:
    match json_dict:
        case {'name': 'Joan'}:
            json_dict['status'] = 'vacation'
            return json_dict
        case {'role': 'engineer', 'status': 'promoted'}:
            json_dict['bonus'] = True
            return json_dict
        case _:
            print(f'ERROR: {str(json_dict)} not supported')
            return None



if __name__ == '__main__':
  json_object = {
   'name': 'Joan', 'status': 'full-time', 'role': 'marketing director', 'bonus': False
  }
  print(mappings_matching(json_object))
  # {'name': 'Joan', 'status': 'vacation', 'role': 'marketing director', 'bonus': False}

  json_object = {
   'name': 'Frank', 'status': 'promoted', 'role': 'engineer', 'bonus': False
  }
  print(mappings_matching(json_object))
  # {'name': 'Frank', 'status': 'promoted', 'role': 'engineer', 'bonus': True}

  json_object = {
   'name': 'Frank', 'status': 'promoted', 'role': 'account manager', 'bonus': False
  }
  print(mappings_matching(json_object))
  # ERROR: {'name': 'Frank', 'status': 'promoted', 'role': 'account manager', 'bonus': False} not supported



The function mappings_matching first attempts to match the single value Joan for the key name, then attempts to match two values, engineer and promoted, for the respective keys role and status.
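Scala does not have a dedicated mapping pattern; a comparable effect can be approximated with pattern guards on a Map. The following is a rough, illustrative sketch, not part of the original example:

def mappingsMatching(record: Map[String, Any]): Option[Map[String, Any]] = record match {
    case m if m.get("name").contains("Joan") =>
        Some(m + ("status" -> "vacation"))
    case m if m.get("role").contains("engineer") && m.get("status").contains("promoted") =>
        Some(m + ("bonus" -> true))
    case _ =>
        println(s"ERROR: $record not supported")
        None
}

mappingsMatching(Map("name" -> "Frank", "status" -> "promoted", "role" -> "engineer"))
// Adds 'bonus' -> true to the record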

Python class pattern

The previous section dealt with matching against built-in types. Let's now look at custom types (classes) that apply an operation, such as adding or multiplying two values.
First, we define a data class, Operator, which is fully defined by two attributes:
  • name, the name of the operator, with type string
  • args, the arguments of the operator, with type tuple.
In order to succeed, the operation should be
  • supported (its name is a known operator)
  • given exactly two arguments.
These two conditions define the context of the pattern matching. The method Operator.__call__ generates the string representation of the operation op(args) (e.g., +(4, 5)).
The method attempts to match 1) the name of the operator and 2) the number of arguments (which is expected to be 2 for addition and multiplication).

from typing import Any, AnyStr, Tuple
from dataclasses import dataclass


@dataclass
class Operator:
    name: AnyStr
    args: Tuple

    def __call__(self) -> AnyStr:
        match (self.name, len(self.args)):       # Minimum condition for matching
            case ('multiply', 2):
                value = self.args[0]*self.args[1]
                return f'{self.args[0]} x {self.args[1]} = {value}'
            case ('add', 2):
                value = self.args[0] + self.args[1]
                return f'{self.args[0]} + {self.args[1]} = {value}'
            case _:
                return "Undefined operation"

    def __str__(self):
        return f'{self.name}: {str(self.args)}'


if __name__ == '__main__':
    operator = Operator("add", (3.5, 6.2))
    print(operator)          # add: (3.5, 6.2)


Now let's match an object of any type to perform the operation. The process follows two steps:
  1. Match the type of the input against Operator
  2. Match the attributes of the operator by invoking the method Operator.__call__ described above.

def object_matching(obj: Any) -> AnyStr:
    match obj:                                     # First match: Is an operator?
        case Operator('multiply', _):      # Second match: Are operator attributes valid?
            return obj()
        case Operator(_, _):
            return obj()
        case _:
            return f'Type not found: {str(obj)}'


if __name__ == '__main__':
    operator = Operator("add", (3.5, 6.2))
    print(object_matching(operator))               # 3.5 + 6.2 = 9.7
    operator = Operator("multiply", (3, 2))
    print(object_matching(operator))               # 3 x 2 = 6
    operator = Operator("multiply", (1, 3, 9))
    print(object_matching(operator)) # Undefined operation
operator = Operator("divided", (3, 3)) print(object_matching(operator)) vvvv# Undefined operation print(object_matching(3.4)) b. # Type not find 3.4



This post illustrates some applications of the structural pattern matching feature introduced in Python 3.10. There are many more patterns worth exploring [ref 4].

Thank you for reading this article. For more information ...

References




---------------------------
Patrick Nicolas has over 25 years of experience in software and data engineering, architecture design and end-to-end deployment and support with extensive knowledge in machine learning. 
He has been director of data engineering at Aideo Technologies since 2017 and he is the author of "Scala for Machine Learning" Packt Publishing ISBN 978-1-78712-238-3

Monday, July 4, 2022

Manage Memory in Deep Java Library

Target audience: Advanced
Estimated reading time: 4'

This post introduces some techniques to monitor memory usage and leaks in machine learning applications using the Deep Java Library (DJL) [ref 1]. This bag of tricks is far from exhaustive.



DJL is an open-source framework that supports distributed inference in Java for deep learning engines such as MXNet, TensorFlow or PyTorch.
The training of deep learning models may require a very significant amount of floating-point computation, which is best supported by GPUs. However, the JVM memory model is incompatible with the column-based resident memory required by the GPU.

Vectorization libraries such as BLAS are implemented in C/C++ and support fast execution of linear algebra operations. The ubiquitous Python numerical library, NumPy [ref 2], commonly used in data science, is a wrapper around these low-level math functions. The NDArray interface used in DJL provides Java developers with similar functionality.


Notes
  • The code snippets in this post are written in Scala but can easily be reworked in Java.

The basics

Memory types
DJL supports monitoring three memory components (a minimal JMX sketch for reading heap and non-heap usage follows the list):
  • Resident Set Size (RSS) is the portion of a process' memory held in RAM that cannot be swapped out.
  • Heap is the section of memory used for dynamically allocated objects.
  • Non-heap is the section encompassing static memory and stack allocation.
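For reference, heap and non-heap usage can be read directly through the standard JMX memory bean; here is a minimal Scala sketch (RSS is an operating-system level metric and is not exposed by this bean):

import java.lang.management.ManagementFactory

// Read the current heap and non-heap usage from the JMX MemoryMXBean
val memoryMXBean = ManagementFactory.getMemoryMXBean
val heapUsedKB = memoryMXBean.getHeapMemoryUsage.getUsed / 1024.0
val nonHeapUsedKB = memoryMXBean.getNonHeapMemoryUsage.getUsed / 1024.0
println(s"Heap: $heapUsedKB KB, Non-heap: $nonHeapUsedKB KB")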

Tensor representation

Deep learning frameworks operate on tensors. Those tensors are implemented as NDArray objects, created dynamically from arrays of values (integers, floats, ...). NDManager is a memory collector/manager native to the underlying C++ implementation of the various deep learning frameworks. Its purpose is to create and delete (close) NDArray instances. NDManager has a hierarchical (single-root tree) structure: a child manager can be spawned from a parent [ref 3].
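A minimal sketch of the parent/child relationship (the variable names are illustrative):

import ai.djl.ndarray.NDManager

val root = NDManager.newBaseManager()   // Root of the manager tree
val child = root.newSubManager()        // Child manager spawned from the root
child.close()                           // Releases the resources owned by the child
root.close()                            // Releases the root and any remaining children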


Let's consider the following simple example of computing the mean of a sequence of floating-point values.
 
 
import scala.util.Random
import ai.djl.ndarray.NDManager

// Set up the root memory manager
val ndManager = NDManager.newBaseManager()

val input = Array.fill(1000)(Random.nextFloat())
// Allocate resources outside the JVM
val ndInput = ndManager.create(input)
val ndMean = ndInput.mean()
val mean = ndMean.toFloatArray.head

// Release ND resources
ndManager.close()
 

The steps implemented in the code snippet are:
  1. Instantiate the root resource manager, ndManager
  2. Create an array of 1,000 random floating-point values
  3. Convert it into an ND array, ndInput
  4. Compute the mean, ndMean
  5. Convert the result back to Java data types
  6. Finally, close the root manager.

The root NDManager can be broken down into child managers to allow a finer granularity in the allocation and release of resources. The following method, computeMean, instantiates a child manager, subNDManager, to compute the mean value. The child manager has to be explicitly closed (releasing its associated resources) before the function returns.
The memory associated with the local ND variables, ndInput and ndMean, is automatically released when the child manager is closed.

 
import ai.djl.ndarray.NDManager

def computeMean(input: Array[Float], ndManager: NDManager): Float =
   if (input.nonEmpty) {
      val subNDManager = ndManager.newSubManager()
      val ndInput = subNDManager.create(input)
      val ndMean = ndInput.mean()
      val mean = ndMean.toFloatArray.head

      // Local ND resources, ndInput and ndMean, are released
      // when the child manager is closed
      subNDManager.close()
      mean
   }
   else
      0.0F
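A brief usage sketch, assuming the caller owns the root manager and closes it once done:

import scala.util.Random
import ai.djl.ndarray.NDManager

val rootManager = NDManager.newBaseManager()
val mean = computeMean(Array.fill(1000)(Random.nextFloat()), rootManager)
println(mean)
rootManager.close()     // Releases the root manager and any remaining children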
  


JMX to the rescue

The JVM provides developers with the ability to access operating system metrics, such as CPU or heap consumption, through the Java Management Extensions (JMX) interface [ref 4].
The DJL class MemoryTrainingListener leverages this JMX monitoring capability. It provides developers with a simple method, collectMemoryInfo, to collect these metrics.

First, we need to instruct DJL to enable the collection of memory statistics through a Java property:

  
System.setProperty("collect-memory", "true") 
 

Similar to the VisualVM heap memory snapshots described in the next section, we can collect memory metrics (RSS, heap and non-heap) before and after each NDArray object is created or released.

  
import ai.djl.metric.Metrics
import ai.djl.ndarray.NDManager
import ai.djl.training.listener.MemoryTrainingListener

def computeMean(
   input: Array[Float],
   ndManager: NDManager,
   metricName: String): Float = {

    val subManager = ndManager.newSubManager()
    // Initialize a new metrics container
    val metrics = new Metrics()

    // Initialize the collection of memory related metrics
    MemoryTrainingListener.collectMemoryInfo(metrics)
    val initVal = metrics.latestMetric(metricName).getValue.longValue

    // Allocate the ND resources on the child manager
    val ndInput = subManager.create(input)
    val ndMean = ndInput.mean()

    collectMetric(metrics, initVal, metricName)
    val mean = ndMean.toFloatArray.head

    // Close the output array and collect metrics
    ndMean.close()
    collectMetric(metrics, initVal, metricName)

    // Close the input array and collect metrics
    ndInput.close()
    collectMetric(metrics, initVal, metricName)

    // Close the sub-manager and collect metrics
    subManager.close()
    collectMetric(metrics, initVal, metricName)
    mean
}

First, we instantiate a Metrics instance that is passed to the successive snapshots. Given the metrics and the current NDManager, we establish a baseline value for the selected metric (e.g., heap size), initVal. We then collect the value of the metric after each creation and release of NDArray instances (collectMetric) in our mean computation example.

Here is a simple snapshot method that computes the increase or decrease in heap memory relative to the baseline.
 
 
def collectMetric(
  metrics: Metrics, 
  initVal: Long, 
  metricName: String): Unit = {

    MemoryTrainingListener.collectMemoryInfo(metrics)  
    val newVal = metrics.latestMetric(metricName).getValue.longValue
    println(s"$metricName: ${(newVal - initVal)/1024.0} KB")
}
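A usage sketch of the instrumented computeMean; note that the metric name "Heap" is an assumption about the key under which MemoryTrainingListener registers the heap measurement:

import scala.util.Random
import ai.djl.ndarray.NDManager

System.setProperty("collect-memory", "true")   // Enable memory metrics collection

val rootManager = NDManager.newBaseManager()
// "Heap" is assumed to be the metric name registered by collectMemoryInfo
val mean = computeMean(Array.fill(1000)(Random.nextFloat()), rootManager, "Heap")
rootManager.close()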


Memory leaks detection

I have been using a combination of several investigative techniques to estimate the source of a memory leak.

MemoryTrainingListener.debugDump
This method dumps basic memory and CPU statistics into a local file for a given Metrics instance:

 
  MemoryTrainingListener.debugDump(metrics, outputFile)
  
 
Output
Heap.Bytes:72387328|#Host:10.5.67.192
Heap.Bytes:74484480|#Host:10.5.67.192
NonHeap.Bytes:39337256|#Host:10.5.67.192
NonHeap.Bytes:40466888|#Host:10.5.67.192
cpu.Percent:262.2|#Host:10.5.67.192
cpu.Percent:262.2|#Host:10.5.67.192
rss.Bytes:236146688|#Host:10.5.67.192
rss.Bytes:244297728|#Host:10.5.67.192

NDManager.cap

It is not uncommon for NDArray objects associated with a sub-manager not to be closed properly. One simple safeguard is to prevent new objects from being allocated directly on the parent manager.


 
// Set up the root memory manager
val ndManager = NDManager.newBaseManager()

// Protect the parent/root manager from
// accidental allocation of NDArray objects
ndManager.cap()

// Allocations now have to go through a child manager
val subManager = ndManager.newSubManager()
val ndInput = subManager.create(input)

  


Profilers
For reference, DJL introduces a set of experimental profilers to support the investigation of memory consumption bottlenecks [ref 5].

VisualVM
We selected VisualVM [ref 6] among the various JVM profiling solutions to highlight some key statistics when investigating a memory leak. VisualVM is a utility that has to be downloaded from the Oracle site; it is not bundled with the JDK.

A simple way to identify excessive memory consumption is to take regular snapshots, or dumps, of the objects allocated on the heap, as illustrated below.


VisualVM has an intuitive UI to drill down into sequences or composite objects. Besides quantifying memory consumption during inference, the following detailed view illustrates the hierarchical nature of the ND manager.






Environments: JDK 11, Scala 2.12.15, Deep Java Library 0.20.0



---------------------------
Patrick Nicolas has over 25 years of experience in software and data engineering, architecture design and end-to-end deployment and support with extensive knowledge in machine learning. 
He has been director of data engineering at Aideo Technologies since 2017 and he is the author of "Scala for Machine Learning" Packt Publishing ISBN 978-1-78712-238-3