- Library versions: python 3.11, JAX 0.4.18, Jax-metal 0.0.4 (Mac M1/M2), NumPy 1.26.0, matplotlib 3.8.0
- To enhance the readability of the algorithm implementations, we have omitted non-essential code elements like error checking, comments, exceptions, validation of class and method arguments, scoping qualifiers, and import statements.
- The performance evaluation (section Performance: JAX vs NumPy) relies on an AWS m4.2xlarge EC2 instance for the CPU benchmark and a p3.2xlarge instance, with 8 virtual cores, 64 GB of memory, and an Nvidia V100 GPU, for the GPU benchmark.
- JAX provides developers with a profiler to generate traces that can be visualized using the Perfetto visualizer.
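For illustration, here is a minimal sketch of capturing a trace that can be inspected in Perfetto; the trace directory is an arbitrary choice and the computation is purely illustrative:

import jax
import jax.numpy as jnp

# Capture a trace of a small computation; with create_perfetto_link=True,
# JAX prints a link that opens the captured trace in the Perfetto UI.
with jax.profiler.trace('/tmp/jax-trace', create_perfetto_link=True):
    x = jnp.arange(1_000_000, dtype=jnp.float32)
    jnp.dot(x, x).block_until_ready()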
Introduction
JAX [ref 2] is a numerical computing and machine learning library in Python, developed by DeepMind, that builds upon the foundation of NumPy. JAX offers the following capabilities, illustrated in a short sketch after this list:
- Composable function transformations.
- Auto-vectorization of data batches, enabling parallel processing.
- First and second-order automatic differentiation for various numerical functions.
- Just-in-time compilation for GPU execution [ref 3].
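As a minimal sketch (the function and batch values are illustrative), these transformations compose directly:

import jax
import jax.numpy as jnp

def f(x: float) -> float:
    return jnp.sin(x) * x**2

df = jax.grad(f)                       # first-order derivative of f
batched_df = jax.vmap(df)              # auto-vectorized over a batch of inputs
fast_batched_df = jax.jit(batched_df)  # just-in-time compiled through XLA

x = jnp.linspace(0.0, 1.0, 8)
print(fast_batched_df(x))              # 8 derivative values computed in one call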
Components
- Autograd: an upgraded version of the Autograd library that improves the performance of automatic differentiation.
- Accelerated Linear Algebra (XLA): JAX uses XLA to compile and run your NumPy code on accelerators.
- Just-in-time compilation (JIT): compiles Python functions through XLA for execution on accelerators (see the sketch after this list).
- Perfetto: Visualization of profiler trace data.
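As a small illustration (a minimal sketch; the function is arbitrary), jax.make_jaxpr exposes the intermediate representation that XLA compiles:

import jax
import jax.numpy as jnp

def h(x: jnp.ndarray) -> jnp.ndarray:
    return jnp.sinh(x) + jnp.cos(x)

# Print the jaxpr, the intermediate representation that XLA compiles for this function
print(jax.make_jaxpr(h)(jnp.ones(4)))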
Installation
- CPU only: pip install --upgrade "jax[cpu]"
- GPU (CUDA): nvcc --version   # -> CUDA version, used to select the matching JAX wheel
- pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
- Apple silicon (jax-metal): python3 -m venv ~/jax-metal
- source ~/jax-metal/bin/activate
- python -m pip install jax-metal
- pip install ml_dtypes==0.2.0
- Conda: conda install jax -c conda-forge
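Once installed, a quick sanity check (a minimal sketch) confirms the version and the backend JAX picked up:

import jax
import jax.numpy as jnp

print(jax.__version__)     # installed JAX version
print(jax.devices())       # available devices: CPU, GPU or Metal backend
print(jnp.arange(3) + 1)   # a trivial computation on the default device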
Automatic differentiation
Overview
- It interprets the code that evaluates a function and uses that interpretation to compute the function's derivative.
- It computes derivatives efficiently in software, without requiring a closed-form (analytical) expression.
- The underlying computation graph does not include data type conversions (e.g., Python values to JAX or NumPy arrays).
- The limitation of forward mode is that the gradient is computed by re-executing the program from the start. The alternative is to store the derivatives so they can be chained and evaluated during a backward pass: reverse-mode automatic differentiation (see the sketch after this list).
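As a small sketch (the quadratic function is illustrative), both modes are available in JAX through jacfwd (forward) and jacrev (reverse); for a scalar-valued function they produce the same gradient:

import jax
import jax.numpy as jnp

def g(x: jnp.ndarray) -> jnp.ndarray:
    return jnp.sum(x ** 2)     # scalar-valued function of a vector

x = jnp.array([1.0, 2.0, 3.0])
print(jax.jacfwd(g)(x))        # forward mode: [2. 4. 6.]
print(jax.jacrev(g)(x))        # reverse mode: [2. 4. 6.]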
Single variable function
class JaxDifferentiation(object):
    """
    Create the list of derivatives of first, second, ..., order_derivative order
    :param func Differentiable function
    :param order_derivative Order of the highest derivative
    """
    def __init__(self, func: Callable[[float], float], order_derivative: int):
        assert order_derivative < 5, f'Order of derivatives {order_derivative} should be [0, 4]'

        # Build the list of derivatives f, f', f", ...
        self.derivatives: List[Callable[[float], float]] = [func]
        temp = func
        for _ in range(order_derivative):
            # Compute the next order derivative of the single-variable function
            # (grad belongs to the jax package, not jax.numpy)
            temp = jax.grad(temp)
            self.derivatives.append(temp)

    def __call__(self, x: float) -> List[float]:
        """ Compute the derivatives of all orders for the value x """
        return [derivative(x) for derivative in self.derivatives]
# Function definition
def func1(x: float) -> float:
    return 2.0*x**4 + x**3

# First order derivative
def dfunc1(x: float) -> float:
    return 8.0*x**3 + 3*x**2

# Second order derivative
def ddfunc1(x: float) -> float:
    return 24.0*x**2 + 6.0*x

# Compare the analytical (oracle) derivatives with the JAX-computed derivatives at value = 2.0
funcs1 = [func1, dfunc1, ddfunc1]
value = 2.0

jax_differentiation = JaxDifferentiation(func1, len(funcs1))
compared = [f'{oracle}, {jax_value}, {oracle - jax_value}'
            for oracle, jax_value in zip([func(value) for func in funcs1], jax_differentiation(value))]
print(compared)
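For value = 2.0 the oracle values are f(2) = 40.0, f'(2) = 76.0 and f''(2) = 108.0, so the differences in the last column should be zero up to floating-point precision.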
Multi-variable function
# Function definition
def func2(x: List[float]) -> float:
    return 2.0*x[0]*x[0] - 3.0*x[0]*x[1] + x[2]

# Partial derivative over x
def dfunc2_x(x: List[float]) -> float:
    return 4.0*x[0] - 3.0*x[1]

# Partial derivative over y
def dfunc2_y(x: List[float]) -> float:
    return -3.0*x[0]

# Partial derivative over z
def dfunc2_z(x: List[float]) -> float:
    return 1.0

# Invoke the forward-mode Jacobian (jacfwd belongs to the jax package, not jax.numpy)
dfunc2 = jax.jacfwd(func2)
y = [2.0, -1.0, 6.0]
derivatives = dfunc2(y)

print(f'df/dx: {derivatives[0]}, {dfunc2_x(y)}\ndf/dy: {derivatives[1]}, {dfunc2_y(y)}\ndf/dz: {derivatives[2]}, {dfunc2_z(y)}')
The output lists, for each partial derivative, the JAX-computed value followed by the oracle (analytical) value.
Performance: JAX vs NumPy
class JaxNumpyData(object):
    """
    Initialize the NumPy and JAX functions used to process data (arrays)
    :param np_func NumPy numerical function
    :param jnp_func Corresponding JAX numerical function
    """
    def __init__(self,
                 np_func: Callable[[np.array], np.array],
                 jnp_func: Callable[[jnp.array], jnp.array]):
        self.np_func = np_func
        self.jnp_func = jnp_func

    def compare(self, full_data_size: int, func_label: AnyStr):
        """
        Compare the execution time of the NumPy, JAX and JAX-JIT implementations
        over sub-data sets of increasing size
        :param full_data_size Size of the original dataset used to extract the sub-data sets
        :param func_label Label used for performance results and plotting
        """
        for index in range(1, 20):
            fraction = 0.05 * index
            data_size = int(full_data_size*fraction)

            # Execute on full_data_size*fraction elements using NumPy
            x_0 = np.linspace(0.0, 100.0, data_size)
            result1 = self.map_numpy(x_0, f'numpy_{func_label}')

            # Execute on full_data_size*fraction elements using JAX and JAX-JIT
            x_1 = jnp.linspace(0.0, 100.0, data_size)
            result2 = self.map_jax(x_1, f'jax_{func_label}')
            result3 = self.map_jit(x_1, f'jit_{func_label}')

            del x_0, x_1, result1, result2, result3

    @time_it
    def map_numpy(self, np_x: np.array, label: AnyStr) -> np.array:
        """ Process the NumPy array np_x through the NumPy function np_func """
        return self.np_func(np_x)

    @time_it
    def map_jax(self, jnp_x: jnp.array, label: AnyStr) -> jnp.array:
        """ Process the JAX array jnp_x through the JAX function jnp_func """
        # block_until_ready forces JAX's asynchronous dispatch to complete before timing stops
        return self.jnp_func(jnp_x).block_until_ready()

    @time_it
    def map_jit(self, jnp_x: jnp.array, label: AnyStr) -> jnp.array:
        """ Process the JAX array jnp_x through the JAX function jnp_func compiled with JIT """
        from jax import jit
        return jit(self.jnp_func)(jnp_x).block_until_ready()
CPU
def np_func1(x: np.array) -> np.array:
    return np.sinh(x) + np.cos(x)

def jnp_func1(x: jnp.array) -> jnp.array:
    return jnp.sinh(x) + jnp.cos(x)

def np_func2(x: np.array) -> np.array:
    return np.mean(x)

def jnp_func2(x: jnp.array) -> jnp.array:
    return jnp.mean(x)
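The driver used to collect the measurements is not shown; a minimal sketch, assuming the JaxNumpyData class above and an illustrative dataset size and labels, could be:

full_data_size = 10_000_000

# Benchmark sinh + cos, then the mean, with NumPy, JAX and JAX-JIT
JaxNumpyData(np_func1, jnp_func1).compare(full_data_size, 'sinh_cos')
JaxNumpyData(np_func2, jnp_func2).compare(full_data_size, 'mean')

# Timings accumulated by the time_it decorator (see Appendix)
for label, timings in timing_stats.items():
    print(f'{label}: {sum(timings)/len(timings):.3f} secs on average')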
GPU
Conclusion
Thank you for reading this article. For more information ...
References
Appendix
timing_stats = {}

def time_it(func):
    """ Decorator to time the execution of methods; the label is the third positional argument """
    def wrapper(*args, **kwargs):
        start = time.time()
        func(*args, **kwargs)
        elapsed = time.time() - start
        key: AnyStr = args[2]
        print(f'{key}\t{elapsed:.3f} secs.')

        # Accumulate the timings per label so they can be averaged or plotted
        cur_list = timing_stats.get(key)
        if cur_list is None:
            cur_list = [elapsed]
        else:
            cur_list.append(elapsed)
        timing_stats[key] = cur_list
        return 0
    return wrapper
He has been Director of Data Engineering at Aideo Technologies since 2017 and is the author of "Scala for Machine Learning", Packt Publishing, ISBN 978-1-78712-238-3.