
Monday, October 23, 2023

Supercharge Python with DeepMind's JAX

Target audience: Intermediate
Estimated reading time: 5'
Ever hoped for NumPy to offer automatic differentiation and run math computations on a GPU? You might find DeepMind's JAX to be the solution.

In this piece, we'll delve into JAX's automatic differentiation capabilities and assess how its just-in-time execution compares to numpy.


Notes:
  • Library versions: python 3.11, JAX 0.4.18, Jax-metal 0.0.4 (Mac M1/M2), NumPy 1.26.0, matplotlib 3.8.0
  • To enhance the readability of the algorithm implementations, we have omitted non-essential code elements like error checking, comments, exceptions, validation of class and method arguments, scoping qualifiers, and import statements.
  • The performance evaluation (section Performance: JAX vs NumPy) relies on an AWS m4.2xlarge EC2 instance for the CPU tests and a p3.2xlarge instance, equipped with 8 virtual cores, 64GB of memory and an Nvidia V100 GPU, for the GPU tests.
  • JAX provides developers with a profiler to generate traces that can be visualized using the Perfetto visualizer; a minimal profiling sketch follows these notes.
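A minimal sketch of collecting a trace, assuming a scratch directory /tmp/jax-trace for the trace files:

import jax
import jax.numpy as jnp

# Collect a trace that can be loaded into the Perfetto UI (https://ui.perfetto.dev)
with jax.profiler.trace('/tmp/jax-trace', create_perfetto_link=True):
    x = jnp.arange(1_000_000)
    jnp.sin(x).block_until_ready()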

Introduction

As a quick recall, NumPy stands as a Python library for numerical and scientific computation. It equips data scientists and engineers with capabilities for working with multidimensional arrays, performing speedy array operations, and handling fundamental tasks in linear algebra and statistics [ref 1].

JAX [ref 2] is a numerical computing and machine learning library in Python, developed by DeepMind, that builds upon the foundation of NumPy. JAX offers the following capabilities, illustrated in the short sketch after this list:

  • Composable function transformations.
  • Auto-vectorization of data batches, enabling parallel processing.
  • First and second-order automatic differentiation for various numerical functions.
  • Just-in-time compilation for GPU execution [ref 3].
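As a brief illustration, here is a minimal sketch, with an arbitrary function f, that composes automatic differentiation, just-in-time compilation and auto-vectorization:

import jax
import jax.numpy as jnp

def f(x: float) -> float:
    return x**2 + jnp.sin(x)

df = jax.grad(f)                 # First order automatic differentiation
fast_df = jax.jit(df)            # Just-in-time compilation through XLA
batched_df = jax.vmap(fast_df)   # Auto-vectorization over a batch of inputs

print(batched_df(jnp.linspace(0.0, 1.0, 8)))

The three transformations compose freely, which is the core design idea behind JAX.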

Components

  • AutoGrad: An upgraded version of Autograd that improves the performance of automatic differentiation.
  • Accelerated Linear Algebra (XLA): JAX uses the XLA compiler to compile and run NumPy code on accelerators.
  • Just-in-time compilation (JIT): Runs computations through XLA for faster execution.
  • Perfetto: Visualization of profiler trace data.

Installation

Here is an overview of the basic steps for installing JAX. It is advisable to consult the installation guide, as each environment has specific requirements [ref 4]. A quick sanity check follows the installation commands.

CPU (MacOS, Linux)
  • pip install --upgrade "jax[cpu]"
GPU (Linux/CUDA)
  1. nvcc --version   # -> CUDA version to be used in the pip install below
  2. pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
GPU (MacOS/mps)
  • python3 -m venv ~/jax-metal
  • source ~/jax-metal/bin/activate
  • python -m pip install jax-metal
  • pip install ml_dtypes==0.2.0
Conda
  • conda install jax -c conda-forge
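Once JAX is installed, a quick sanity check confirms the version and which devices were picked up (the output shown in the comments is indicative only):

import jax

print(jax.__version__)    # e.g. 0.4.18
print(jax.devices())      # e.g. [CpuDevice(id=0)] on CPU or [cuda(id=0)] on GPU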

Automatic differentiation

Overview

Automatic differentiation is a tool that facilitates the automatic calculation of derivatives for a specified mathematical function [ref 5].

This technique efficiently determines precise derivatives by retaining details during the forward pass, which are then utilized during the backward pass. Essentially, 
  • It interprets code that calculates a function and leverages it to compute the function's derivative.
  • It crafts a software approach to efficiently determine the derivatives, bypassing the need for a closed-form solution.
This article focuses on forward mode automatic differentiation, which consists of replacing each primitive operation in the original program with its differential analogue.

To illustrate the concept, let's consider the function \[f(x,y,z)=2x^{2}-3xy+z \] and build its forward computation graph:
Fig 1. Simplified forward computation graph

Notes
  • This computation graph does not include data type conversion (Python values to JAX or NumPy arrays).
  • The limitation of the forward mode is that the gradient is computed by re-executing the program all over again. The solution is to store the derivatives so they can be chained and computed during a backward pass: reverse mode automatic differentiation. A minimal dual-number sketch of the forward mode follows these notes.
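To make the forward mode concrete, here is a minimal dual-number sketch, independent of JAX; the Dual class is ours, for illustration only. It propagates a value together with its derivative through the function f(x,y,z) = 2x^2 - 3xy + z:

class Dual:
    """ Value and derivative propagated together (forward mode) """
    def __init__(self, value: float, deriv: float):
        self.value = value
        self.deriv = deriv

    def __add__(self, other: 'Dual') -> 'Dual':
        return Dual(self.value + other.value, self.deriv + other.deriv)

    def __sub__(self, other: 'Dual') -> 'Dual':
        return Dual(self.value - other.value, self.deriv - other.deriv)

    def __mul__(self, other: 'Dual') -> 'Dual':
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

def f(x: Dual, y: Dual, z: Dual) -> Dual:
    two = Dual(2.0, 0.0)
    three = Dual(3.0, 0.0)
    return two * x * x - three * x * y + z

# Seed dx/dx = 1 to obtain df/dx at (2, -1, 6); expected 4*2 - 3*(-1) = 11
result = f(Dual(2.0, 1.0), Dual(-1.0, 0.0), Dual(6.0, 0.0))
print(result.value, result.deriv)   # 20.0 11.0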

Single variable function

Let's implement a class, JaxDifferentiation, that wraps the computation of the first, second, ... order derivatives of a function of a single variable \[func: \mathbb{R} \rightarrow \mathbb{R}\]. The derivatives of the various orders are computed in the constructor and the method __call__ returns the list [f, f', f'', ...].
from typing import Callable, List
import jax

class JaxDifferentiation(object):
    """
        Create a set of derivatives of first, second, ..., order_derivative order
        :param func Differentiable function
        :param order_derivative Order of the derivatives
    """
    def __init__(self, func: Callable[[float], float], order_derivative: int):
        assert 0 <= order_derivative < 5, f'Order of derivatives {order_derivative} should be [0, 4]'

        # Build the list of derivatives f, f', f'', ...
        self.derivatives: List[Callable[[float], float]] = [func]
        temp = func
        for _ in range(order_derivative):
            # Compute the next order derivative of the single variable function
            temp = jax.grad(temp)
            self.derivatives.append(temp)

    def __call__(self, x: float) -> List[float]:
        """ Compute the derivatives of all orders for the value x """
        return [derivative(x) for derivative in self.derivatives]

Let's compute the derivatives of the following function: \[f(x)=2x^{4}+x^{3}, \quad \frac{\mathrm{d} f}{\mathrm{d} x}=8x^{3}+3x^{2}, \quad \frac{\mathrm{d}^{2}f}{\mathrm{d}x^{2}}=24x^{2}+6x\] The first and second order derivatives are provided for evaluation purposes (the oracle).
# Function definition
def func1(x: float) -> float:
   return 2.0*x**4 + x**3

# First order derivative
def dfunc1(x: float) -> float:
   return 8.0*x**3 + 3*x**2

# Second order derivative
def ddfunc1(x: float) -> float:
   return 24.0*x**2 + 6.0*x

funcs1 = [func1, dfunc1, ddfunc1]
jax_differentiation = JaxDifferentiation(func1, len(funcs1) - 1)

# Compare the oracle values with the JAX derivatives at x = 2.0
compared = [f'{oracle}, {jax_value}, {oracle - jax_value}'
                for oracle, jax_value in zip([func(2.0) for func in funcs1], jax_differentiation(2.0))]
print(compared)
Output
Oracle    JAX       Difference
40.0      40.0      0.0
76.0      76.0      0.0
108.0     108.0     0.0

Multi-variable function

The next step is to evaluate the computation of the partial derivatives of a multi-variable function f(x, y, ...), i.e. the Jacobian \[J=[\frac{\partial f}{\partial x_{1}}, \frac{\partial f}{\partial x_{2}}, ..., \frac{\partial f}{\partial x_{n}}]\]
Let's consider the following function, for which the first order partial derivatives (Jacobian vector) are provided: \[f(x,y,z)=2x^{2}-3xy+z, \quad \frac{\partial f}{\partial x} = 4x-3y, \quad \frac{\partial f}{\partial y} = -3x, \quad \frac{\partial f}{\partial z} = 1\]
# Function definition
def func2(x: List[float]) -> float:
  return 2.0*x[0]*x[0] - 3.0*x[0]*x[1] + x[2]

# Partial derivative over x
def dfunc2_x(x: List[float]) -> float:
  return 4.0*x[0] - 3.0*x[1]

# Partial derivative over y
def dfunc2_y(x: List[float]) -> float:
  return -3.0*x[0]

# Partial derivative over z
def dfunc2_z(x: List[float]) -> float:
  return 1.0

Let's compare the output of the direct computation of the symbolic derivatives (Oracle) dfunc2_x, dfunc2_y and dfunc2_z with the partial derivatives computed by JAX.
We use the forward mode automatic differentiation function jacfwd to compute the gradient [ref 6]. 
# Invoke the Jacobian vector forward function
dfunc2 = jax.jacfwd(func2)

y = [2.0, -1.0, 6.0]
derivatives = dfunc2(y)

print(f'df/dx: {derivatives[0]}, {dfunc2_x(y)}\ndf/dy: {derivatives[1]}, {dfunc2_y(y)}\ndf/dz: {derivatives[2]}, {dfunc2_z(y)}')
Output
          Oracle    JAX
df/dx:    11.0      11.0
df/dy:    -6.0      -6.0
df/dz:    1.0       1.0

Note: The reverse mode automatic differentiation JAX method, jacrev, would have produced the same result.
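A minimal check, reusing func2 and the input y from above (a sketch, not part of the original snippet):

from jax import jacrev

dfunc2_rev = jacrev(func2)
print(dfunc2_rev(y))    # [11.0, -6.0, 1.0], identical to the jacfwd output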

Performance: JAX vs NumPy

A significant drawback of the NumPy library is its absence of GPU support. The next objective is to measure the performance gains achieved by JAX, with and without its just-in-time compiler, on both CPU and GPU.

To facilitate this, we establish a class named JaxNumpyData containing two functions: np_func, which utilizes NumPy, and jnp_func, its JAX counterpart. These functions are applied to datasets of various sizes. The compare method extracts subsets of the initial dataset, in 5% increments, using a basic fraction-based approach.
class JaxNumpyData(object):
    """
        Initialize the NumPy and JAX functions to process data (arrays)
        :param np_func NumPy numerical function
        :param jnp_func Corresponding JAX numerical function
    """
    def __init__(self,
                 np_func: Callable[[np.array], np.array],
                 jnp_func: Callable[[jnp.array], jnp.array]):
        self.np_func = np_func
        self.jnp_func = jnp_func

    def compare(self, full_data_size: int, func_label: AnyStr):
        """
        Compare the execution of the NumPy, JAX and JAX-JIT functions on subsets of the dataset
        :param full_data_size Size of the original dataset used to extract the sub-datasets
        :param func_label Label used for performance results and plotting
        """
        for index in range(1, 20):
            fraction = 0.05 * index
            data_size = int(full_data_size*fraction)

            # Execute on the full_data_size*fraction elements using NumPy
            x_0 = np.linspace(0.0, 100.0, data_size)
            result1 = self.map_numpy(x_0, f'numpy_{func_label}')

            # Execute on the full_data_size*fraction elements using JAX and JAX-JIT
            x_1 = jnp.linspace(0.0, 100.0, data_size)
            result2 = self.map_jax(x_1, f'jax_{func_label}')
            result3 = self.map_jit(x_1, f'jit_{func_label}')

            del x_0, x_1, result1, result2, result3

    @time_it
    def map_numpy(self, np_x: np.array, label: AnyStr) -> np.array:
        """ Process the NumPy array, np_x, through the NumPy function np_func """
        return self.np_func(np_x)

    @time_it
    def map_jax(self, jnp_x: jnp.array, label: AnyStr) -> jnp.array:
        """ Process the JAX array, jnp_x, through the JAX function jnp_func """
        return self.jnp_func(jnp_x)

    @time_it
    def map_jit(self, jnp_x: jnp.array, label: AnyStr) -> jnp.array:
        """ Process the JAX array, jnp_x, through the JAX function jnp_func using JIT """
        from jax import jit
        return jit(self.jnp_func)(jnp_x)

The method map_numpy (resp. map_jax and map_jit) applies the NumPy function np_func (resp. the JAX function jnp_func) to the NumPy array np_x (resp. the JAX array jnp_x).

CPU

In this first performance test, we measure the time needed to compute \[f(x) = \sinh(x)+\cos(x)\] on 1,000,000,000 values using NumPy and JAX, with and without the just-in-time compiler.
def np_func1(x: np.array) -> np.array:
    return np.sinh(x) + np.cos(x)

def jnp_func1(x: jnp.array) -> jnp.array:
    return jnp.sinh(x) + jnp.cos(x)
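A minimal driver for this benchmark, assuming the JaxNumpyData class and time_it decorator defined earlier (the label 'sinh_cos' is an arbitrary choice):

data_comparison = JaxNumpyData(np_func1, jnp_func1)
data_comparison.compare(1_000_000_000, 'sinh_cos')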
JAX produces a seven-fold performance improvement over NumPy. The just-in-time compiler adds another 35% improvement.


The second latency test computes the mean value \[mean=\frac{1}{n} \sum_{i=1}^{n} x_{i}\] of a NumPy and a JAX array of 1,200,000,000 values.
def np_func2(x: np.array) -> np.array:
    return np.mean(x)

def jnp_func2(x: jnp.array) -> jnp.array:
    return jnp.mean(x)
The just-in-time compiler outperforms both NumPy and the native JAX library on CPU.


GPU

For this last test, we execute the function \[f(x)=e^{-x} + \cos(x)\] over 200,000,000 values on an Nvidia V100 GPU.
As anticipated, NumPy runs on the CPU of the EC2 instance and therefore cannot match the performance of JAX executing on the Nvidia processor.
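The functions under test follow the same pattern as the CPU benchmarks; a plausible sketch (np_func3 and jnp_func3 are our names, not from the original code):

def np_func3(x: np.array) -> np.array:
    return np.exp(-x) + np.cos(x)

def jnp_func3(x: jnp.array) -> jnp.array:
    return jnp.exp(-x) + jnp.cos(x)

JaxNumpyData(np_func3, jnp_func3).compare(200_000_000, 'exp_cos')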

Conclusion

In summary, JAX offers data scientists and machine learning engineers a high-performance GPU computing tool that significantly outperforms the NumPy library. Our exploration has only touched the surface of JAX's capabilities, and I encourage readers to delve deeper into features like Autobatching, Vectorization, Generalized convolutions, and its integration with PyTorch and TensorFlow.

Thank you for reading this article. For more information ...


References


Appendix

We include the decorator used for timing the execution of the various functions, for reference.
timing_stats = {}

def time_it(func):
    """ Decorator for timing the execution of methods """

    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        duration = time.time() - start

        # The third positional argument (args[2]) of the decorated methods is the label
        key: AnyStr = args[2]
        print(f'{key}\t{duration:.3f} secs.')

        # Accumulate the execution times recorded for each label
        cur_list = timing_stats.get(key)
        if cur_list is None:
            cur_list = [duration]
        else:
            cur_list.append(duration)
        timing_stats[key] = cur_list
        return result

    return wrapper



---------------------------
Patrick Nicolas has over 25 years of experience in software and data engineering, architecture design and end-to-end deployment and support with extensive knowledge in machine learning. 
He has been director of data engineering at Aideo Technologies since 2017 and he is the author of "Scala for Machine Learning" Packt Publishing ISBN 978-1-78712-238-3

Sunday, September 17, 2023

Compare Python, NumPy and PyTorch Performance

Target audience: Beginner
Estimated reading time: 4'

Recently, I embarked on a healthcare project that involved extracting diagnostic information from Electronic Health Records. While fine-tuning a BERT model, I noticed some atypical latency behaviors. This prompted me to conduct a performance comparison between Python lists, NumPy arrays, and PyTorch tensors.
The implementation relies on a timer decorator to collect latency values.



Notes
  • The implementation uses Python 3.11
  • To enhance the readability of the algorithm implementations, we have omitted non-essential code elements like error checking, comments, exceptions, validation of class and method arguments, scoping qualifiers, and import statements.

Introduction

I assume that most readers are familiar with the various Python, NumPy and PyTorch containers used in this article. But just in case, here is a quick refresh:

Python lists: A list in Python is similar to an array in C, C++, Java or Scala, except that its elements can have different types.

Python arrays: An array is a container which holds a fixed number of items. Contrary to lists, the items of an array must be of the same type. Most data structures make use of arrays to implement their algorithms [ref 1].

NumPy arrays: A NumPy array represents a multidimensional, homogeneous array of fixed-size items. It is implemented as a static buffer of contiguous values of identical type, whose indexing can be dynamically modified to generate matrices, tensors or higher-dimension numerical structures [ref 2].

PyTorch tensors: Similarly to NumPy arrays, PyTorch tensors are multi-dimensional arrays containing elements of a single data type. Tensors share the same semantics and operators as NumPy arrays but also support automatic differentiation and GPU/CUDA math libraries [ref 3]. A short sketch of the four containers follows.
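As a minimal sketch of the four containers under evaluation (the values list is arbitrary):

import array as ar
import numpy as np
import torch

values = [1.0, 2.0, 3.0]

py_list = values                                # Python list: heterogeneous elements allowed
py_array = ar.array('d', values)                # Python array: single element type ('d' = double)
np_array = np.array(values, dtype=np.float32)   # NumPy array: contiguous homogeneous buffer
tensor = torch.tensor(np_array)                 # PyTorch tensor: NumPy-like semantics + autodiff/GPU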

Timing with decorator

Decorators are very powerful tools in Python since they allow programmers to modify the behavior of a function, method or even a class. Decorators wrap another function in order to extend the behavior of the wrapped function, without permanently modifying it [ref 4].

def timeit(func):
    ''' Decorator for timing execution of methods'''

    def wrapper(*args, **kwargs):
        start = time.time()
        func(*args, **kwargs)
        duration = '{:.3f}'.format(time.time() - start)
        logging.info(f'{args[1]}:{args[3]}\t{duration} secs.')
        return 0

    return wrapper


Benchmark implementation

The objective is to automate the comparison of the various frameworks and functions by creating a wrapper class, EvalFunction.
The evaluation class has two arguments:
  • The descriptive name of the function, func_name, used to evaluate the data structures
  • The signature of the function, func, used to evaluate the data structures
import array as ar
import time
import numpy as np
from random import Random
from typing import List, AnyStr, Callable, Any, NoReturn
import math
import torch
from dataclasses import dataclass
import logging
from matplotlib import pyplot as plt

collector = {}
@dataclass
class EvalFunction:
    """
        Data class for the evaluation of Python lists, arrays, NumPy arrays and torch tensors
        :param func_name Description of the function to execute
        :param func Lambda to be executed
    """
    func_name: AnyStr
    func: Callable[[Any], float]

    def compare(self, input_list: List[float], fraction: float = 0.0) -> NoReturn:
        input_max: int = \
            math.floor(len(input_list)*fraction) if 0.0 < fraction <= 1.0 \
            else len(input_list)
        input_data = input_list[:input_max]

        # Execute the lambda through a Python list
        self.__execute('python', input_data, 'list:   ')

        # Execute the lambda through a Python array
        input_array = ar.array('d', input_data)
        self.__execute('python', input_array, 'array:  ')

        # Execute the lambda through a NumPy array
        np_input = np.array(input_data, dtype=np.float32)
        self.__execute('python', np_input, 'lambda: ')

        # Execute the native NumPy methods
        self.__execute('numpy', np_input, 'native: ')

        # Execute the PyTorch method on CPU
        tensor = torch.tensor(np_input, dtype=torch.float32, device='cpu')
        self.__execute('pytorch', tensor, '(CPU):  ')

        # Execute the PyTorch method on GPU
        tensor = torch.tensor(np_input, dtype=torch.float32, device='cuda:0')
        self.__execute('pytorch', tensor, '(CUDA): ')


The implementation of the supporting private method, __execute, is described in the Appendix.

Evaluation

We've chosen a collection of mathematical transformations that vary in complexity and computational demand to evaluate the different frameworks. These transformations compute the mean value produced by the following functions:
\[x_{i}=1+rand[0, 1]\]
\[average(x)=\frac{1}{n}\sum_{i=1}^{n}x_{i}\]
\[sine(x) = \frac{1}{n}\sum_{i=1}^{n}\sin\left ( x_{i} \right )\]
\[sin.exp(x) = \frac{1}{n}\sum_{i=1}^{n}\sin\left ( x_{i} \right ) e^{-x_{i}}\]
\[sin.exp.log(x) = \frac{1}{n}\sum_{i=1}^{n}\left ( \sin\left ( x_{i} \right ) e^{-x_{i}} + \log(1 + x_{i}) \right )\]

# Functions to evaluate data structures
def average(x) -> float:
    return sum(x)/len(x)
def sine(x) -> float:
    return sum([math.sin(t) for t in x])/len(x)
def sin_exp(x) -> float:
    return sum([math.sin(t)*math.exp(-t) for t in x])/len(x)
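The last transformation, sin.exp.log, is not listed above; a plausible implementation following the same pattern (our sketch):

def sin_exp_log(x) -> float:
    return sum([math.sin(t)*math.exp(-t) + math.log(1.0 + t) for t in x])/len(x)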


# Random value generator
rand = Random(42)
num_values = 500_000_000

# Note: a single random draw replicated num_values times
my_list: List[float] = [1.0 + rand.uniform(0.0, 0.1)] * num_values

# Fraction of the original data set of 500 million data points
fractions = [0.2, 0.4, 0.6, 0.8, 1.0]

# Evaluate the latency for sub-datasets of size len(my_list)*fraction
for fraction in fractions:
    eval_func = EvalFunction('sin_exp', sin_exp)
    eval_func.compare(my_list, fraction)

# x-axis values as size=  len(my_list)*fraction
data_sizes = [math.floor(num_values*fraction) for fraction in fractions]

# Invoke the plotting method
plotter = Plotter(data_sizes, collector)
plotter.plot('Sin*exp 500M')

We conducted the test on an AWS EC2 instance of type p3.2xlarge, equipped with 8 virtual cores, 64GB of memory, and an Nvidia V100 GPU. A basic method for plotting the results is provided in the appendix.

Study 1
We compared the computation time required to compute the average of 500 million real numbers stored in a Python list, a Python array, a NumPy array and a PyTorch tensor.


We then compared the computation time required to apply the x → sin(x)·exp(-x) function to 500 million real numbers within a Python list, array, NumPy array, and PyTorch tensor.


Conclusion
  • The performance difference between executing on the GPU versus the CPU becomes more pronounced as the dataset size grows.
  • Predictably, the runtime for both the 'average' and 'sin_exp' functions scales linearly with the size of the dataset when using Python lists or arrays.
  • When executed on the CPU, PyTorch tensors show a 20% performance improvement over NumPy arrays.

Study 2
Let's compare the relative performance of the GPU and CPU during the processing of a large PyTorch tensor.


Conclusion
The size of the dataset has a very limited impact on the performance of processing a PyTorch tensor on the GPU, while the execution time increases linearly on the CPU.

Thank you for reading this article. For more information ...

References


Appendix

The __execute method takes three arguments:
  • The framework, used in the structural pattern match to select the library under evaluation
  • The input data to be processed
  • The label consumed by the timing decorator
The private method __numpy_func applies each of the functions (average, sine, ...) to a NumPy array, np_array, generated from the original list.
The method __pytorch_func applies each function to a torch tensor derived from np_array.


@timeit
def __execute(self, framework: AnyStr, input: Any, label: AnyStr) -> float:
    match framework:
        case 'python':
           return self.func(input)
        case 'numpy':
           return self.__numpy_func(input)
        case 'pytorch':
           return self.__pytorch_func(input)
        case _:
           return -1.0


def __numpy_func(self, np_array: np.array) -> float:
   match self.func_name:
      case 'average':
          return np.average(np_array).item()
      case 'sine':
          return np.average(np.sin(np_array)).item()
      case 'sin_exp':
          return np.average(np.sin(np_array)*np.exp(-np_array)).item()


def __pytorch_func(self, tensor: torch.tensor) -> float:
    match self.func_name:
       case 'average':
          return torch.mean(tensor).item()
       case 'sine':
          return torch.mean(torch.sin(tensor)).item()
       case 'sin_exp':
          return torch.mean(torch.sin(tensor) * torch.exp(-tensor)).item()


A simple class, Plotter, wraps the creation and display of plots using matplotlib.

class Plotter(object):
    markers = ['r--', '^-', '+--', '--', '*-']

    def __init__(self, dataset_sizes: List[int], results_map):
        self.sizes = dataset_sizes
        self.results_map = results_map

    def plot(self, title: AnyStr) -> NoReturn:
        index = 0
        np_sizes = np.array(self.sizes)
        for key, values in self.results_map.items():
            np_values = np.array(values)
            plt.plot(np_sizes, np_values, Plotter.markers[index % len(Plotter.markers)])
            index += 1

        plt.title(title, fontsize=16, fontstyle='italic')
        plt.xlabel('Dataset size', fontsize=13, fontstyle='normal')
        plt.ylabel('time secs', fontsize=13, fontstyle='normal')
        plt.legend(['Python list', 'Python array', 'NumPy w/ lambda', 'NumPy native', 'PyTorch CPU', 'PyTorch GPU'])
        plt.show()


---------------------------
Patrick Nicolas has over 25 years of experience in software and data engineering, architecture design and end-to-end deployment and support with extensive knowledge in machine learning. 
He has been director of data engineering at Aideo Technologies since 2017 and he is the author of "Scala for Machine Learning" Packt Publishing ISBN 978-1-78712-238-3