Sunday, September 17, 2023

Supercharge Python with DeepMind's JAX

Target audience: Intermediate
Estimated reading time: 5'
Ever wished NumPy offered automatic differentiation and could run math computations on a GPU? DeepMind's JAX may be the solution.

In this piece, we'll delve into JAX's automatic differentiation capabilities and assess how its just-in-time compilation compares to NumPy in terms of performance.


Notes:
  • Library versions: python 3.11, JAX 0.4.18, jax-metal 0.0.4 (Mac M1/M2), NumPy 1.26.0, matplotlib 3.8.0
  • To enhance the readability of the algorithm implementations, we have omitted non-essential code elements like error checking, comments, exceptions, validation of class and method arguments, scoping qualifiers, and import statements.
  • The performance evaluation in the section Performance: JAX vs NumPy relies on an AWS m4.2xlarge EC2 instance for CPU and a p3.2xlarge instance, equipped with 8 virtual cores, 64GB of memory, and an Nvidia V100 GPU, for GPU.
  • JAX provides developers with a profiler to generate traces that can be visualized using the Perfetto visualizer.

Introduction

As a quick reminder, NumPy is a Python library for numerical and scientific computation. It equips data scientists and engineers with capabilities for working with multidimensional arrays, performing fast array operations, and handling fundamental tasks in linear algebra and statistics [ref 1].

JAX [ref 2] is a numerical computing and machine learning library in Python, developed by DeepMind, that builds upon the foundation of NumPy. JAX offers:

  • Composable function transformations.
  • Auto-vectorization of data batches, enabling parallel processing.
  • First and second-order automatic differentiation for various numerical functions.
  • Just-in-time compilation for GPU execution [ref 3].
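The composable transformations listed above can be sketched in a few lines. This is an illustrative example rather than code from this article: the toy function loss and its inputs are assumptions, while jax.grad, jax.jit and jax.vmap are JAX's own transformation functions.

```python
import jax
import jax.numpy as jnp

def loss(w: jax.Array, x: jax.Array) -> jax.Array:
    # Toy scalar function of a weight vector w for a single input vector x
    return jnp.sum(jnp.tanh(w * x) ** 2)

# Compose transformations: differentiate with respect to w, then JIT-compile
grad_loss = jax.jit(jax.grad(loss))

# Auto-vectorize the gradient over a batch of inputs (axis 0 of x), sharing w
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

w = jnp.array([0.5, -1.0, 2.0])
xs = jnp.ones((4, 3))          # batch of 4 input vectors
grads = batched_grad(w, xs)    # one gradient per input vector
print(grads.shape)             # (4, 3)
```

Each transformation returns a plain Python function, which is why they can be nested in any order.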

Components

  • Autograd: JAX's automatic differentiation builds on an updated, higher-performance version of the Autograd library.
  • Accelerated Linear Algebra (XLA): JAX uses XLA to compile and run your NumPy code on accelerators.
  • Just-in-time compilation (JIT): Compiles Python functions into optimized XLA computations.
  • Perfetto: Visualization of profiler trace data.

Installation

Here is an overview of basic steps for installing JAX. It is advisable to consult the installation guide as each environment has specific requirements [ref 4].

CPU (MacOS, Linux)
  • pip install --upgrade "jax[cpu]"
GPU (Linux/CUDA)
  1. nvcc --version   # check the installed CUDA version, which determines the JAX package to install
  2. pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
GPU (MacOS/mps)
  • python3 -m venv ~/jax-metal
  • source ~/jax-metal/bin/activate
  • python -m pip install jax-metal
  • pip install ml_dtypes==0.2.0
Conda
  • conda install jax -c conda-forge
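Whatever the installation path, a quick smoke test confirms which devices JAX detects. This is a minimal sketch; jax.devices() reports CPU devices only, or GPU/Metal devices when the corresponding package was installed correctly.

```python
import jax
import jax.numpy as jnp

# List the devices visible to JAX (CPU, GPU or Metal depending on the installation)
devices = jax.devices()
print(jax.__version__, devices)

# Run a tiny computation on the default device
x = jnp.arange(4.0)
print(jnp.dot(x, x))   # 14.0
```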

Automatic differentiation

Overview

Automatic differentiation is a tool that facilitates the automatic calculation of derivatives for a specified mathematical function [ref 5].

This technique efficiently computes exact derivatives by recording intermediate values during the forward pass; in reverse mode, these values are then reused during a backward pass. Essentially,
  • It interprets the code that computes a function and leverages it to compute the function's derivative.
  • It computes derivatives efficiently in software, bypassing the need for a closed-form solution.
This article focuses on forward-mode automatic differentiation, which replaces each primitive operation in the original program with its differential analogue.
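Forward mode can be illustrated without JAX using dual numbers: each value carries its derivative with respect to one seeded input, and each primitive (+, -, *) is replaced by a rule that propagates both. The Dual class below is a hypothetical, minimal sketch, applied to the function f(x,y,z)=2x²-3xy+z used in the next example.

```python
from dataclasses import dataclass

@dataclass
class Dual:
    """Value paired with its derivative (tangent) with respect to one seeded input."""
    val: float
    der: float

    @staticmethod
    def lift(other):
        # Constants have a zero derivative
        return other if isinstance(other, Dual) else Dual(other, 0.0)

    def __add__(self, other):
        other = Dual.lift(other)
        return Dual(self.val + other.val, self.der + other.der)
    __radd__ = __add__

    def __sub__(self, other):
        other = Dual.lift(other)
        return Dual(self.val - other.val, self.der - other.der)

    def __mul__(self, other):
        other = Dual.lift(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * other.val, self.der * other.val + self.val * other.der)
    __rmul__ = __mul__

def f(x, y, z):
    return 2.0 * x * x - 3.0 * x * y + z

# Seed x with tangent 1.0 to obtain df/dx at (2, -1, 6) in a single forward pass
out = f(Dual(2.0, 1.0), Dual(-1.0, 0.0), Dual(6.0, 0.0))
print(out.val, out.der)   # 20.0 11.0
```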

To illustrate the concept, let's consider the function \[f(x,y,z)=2x^{2}-3xy+z\] and build its forward computation graph:
fig 1. Simplified forward computation graph
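JAX can display the graph it records for a function: jax.make_jaxpr traces the function and lists its primitive operations (integer_pow, mul, sub, add, ...). A minimal sketch for the function above:

```python
import jax

def f(x, y, z):
    return 2.0 * x**2 - 3.0 * x * y + z

# Trace f into a jaxpr, the sequence of primitives JAX evaluates in the forward pass
jaxpr = jax.make_jaxpr(f)(2.0, -1.0, 6.0)
print(jaxpr)
```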

Notes
  • This computation graph does not include data type conversion (Python values to JAX or NumPy arrays).
  • The limitation of forward mode is that each pass computes the derivative with respect to a single input, so the full gradient of a function of n variables requires re-executing the program n times. The solution is to store the intermediate values and chain the derivatives during a backward pass: reverse-mode automatic differentiation.
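The contrast between the two modes can be sketched with JAX primitives: jax.jvp performs one forward pass per input direction (three passes here), while jax.grad obtains the whole gradient from a single reverse pass. The function is the one above; the evaluation point is an arbitrary choice.

```python
import jax
import jax.numpy as jnp

def f(v: jax.Array) -> jax.Array:
    x, y, z = v
    return 2.0 * x**2 - 3.0 * x * y + z

point = jnp.array([2.0, -1.0, 6.0])

# Forward mode: one jvp pass per basis direction (n passes for n inputs)
fwd_grad = jnp.array([jax.jvp(f, (point,), (e,))[1] for e in jnp.eye(3)])

# Reverse mode: a single backward pass yields the full gradient
rev_grad = jax.grad(f)(point)

print(fwd_grad, rev_grad)   # both equal to (11, -6, 1)
```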

Single variable function

Let's implement a class, JaxDifferentiation, that wraps the computation of the first, second, ... derivatives of a single-variable function \[func: \mathbb{R} \rightarrow \mathbb{R}\] The derivatives of the various orders are computed in the constructor. The method __call__ returns the list [f(x), f'(x), f"(x), ...].
from typing import Callable, List
import jax


class JaxDifferentiation(object):
    """
        Create the list of derivatives of order 0 (the function itself) up to order_derivative
        :param func Differentiable single-variable function
        :param order_derivative Highest order of derivative, in [0, 4]
    """
    def __init__(self, func: Callable[[float], float], order_derivative: int):
        assert 0 <= order_derivative < 5, f'Order derivatives {order_derivative} should be [0, 4]'

        # Build list of derivatives f, f', f", ....
        self.derivatives: List[Callable[[float], float]] = [func]
        temp = func
        for _ in range(order_derivative):
            # Differentiate the previous derivative to get the next order (jax.grad, not jnp.grad)
            temp = jax.grad(temp)
            self.derivatives.append(temp)

    def __call__(self, x: float) -> List[float]:
        """ Compute derivatives of all orders for value x"""
        return [derivative(x) for derivative in self.derivatives]

Let's compute the derivatives of the following function: \[f(x)=2x^{4}+x^{3} \qquad \frac{df}{dx}=8x^{3}+3x^{2} \qquad \frac{d^{2}f}{dx^{2}}=24x^{2}+6x\] The analytical first and second derivatives are provided for evaluation purposes (oracle).
# Function definition
def func1(x: float) -> float:
    return 2.0*x**4 + x**3

# First order derivative
def dfunc1(x: float) -> float:
    return 8.0*x**3 + 3.0*x**2

# Second order derivative
def ddfunc1(x: float) -> float:
    return 24.0*x**2 + 6.0*x

funcs1 = [func1, dfunc1, ddfunc1]
# Orders 0, 1 and 2: the highest order is len(funcs1) - 1
jax_differentiation = JaxDifferentiation(func1, len(funcs1) - 1)

y = 2.0
compared = [f'{oracle}, {jax_value}, {oracle - jax_value}'
            for oracle, jax_value in zip([func(y) for func in funcs1], jax_differentiation(y))]
print('\n'.join(compared))
Output
Oracle, Jax,   Difference
40.0,   40.0,  0.0
76.0,   76.0,  0.0
108.0,  108.0, 0.0

Multi-variable function

The next step is to evaluate the partial derivatives of a multi-variable function f(x1, x2, ..., xn), collected in the Jacobian vector \[J=\left[\frac{\partial f}{\partial x_{1}}, \frac{\partial f}{\partial x_{2}}, ..., \frac{\partial f}{\partial x_{n}}\right]\]
Let's consider the following function, for which the first-order partial derivatives (Jacobian vector) are provided: \[f(x,y,z)=2x^{2}-3xy+z \qquad \frac{\partial f}{\partial x} = 4x - 3y ; \quad \frac{\partial f}{\partial y} = -3x ; \quad \frac{\partial f}{\partial z} = 1\]
# Function definition
def func2(x: List[float]) -> float:
  return 2.0*x[0]*x[0] - 3.0*x[0]*x[1] + x[2]

# Partial derivative over x
def dfunc2_x(x: List[float]) -> float: return 4.0*x[0] - 3.0*x[1]
# Partial derivative over y
def dfunc2_y(x: List[float]) -> float: return -3.0*x[0]
# Partial derivative over z
def dfunc2_z(x: List[float]) -> float: return 1.0

Let's compare the output of the direct computation of the symbolic derivatives (Oracle) dfunc2_x, dfunc2_y and dfunc2_z with the partial derivatives computed by JAX.
We use the forward mode automatic differentiation function jacfwd to compute the gradient [ref 6]. 
import jax
import jax.numpy as jnp

# Jacobian vector computed with the forward-mode function jax.jacfwd
dfunc2 = jax.jacfwd(func2)

y = jnp.array([2.0, -1.0, 6.0])
derivatives = dfunc2(y)

# Oracle first, then JAX
print(f'df/dx: {dfunc2_x(y)}, {derivatives[0]}\n'
      f'df/dy: {dfunc2_y(y)}, {derivatives[1]}\n'
      f'df/dz: {dfunc2_z(y)}, {derivatives[2]}')
Output
             Oracle, Jax
df/dx:    11.0,     11.0
df/dy:    -6.0,     -6.0
df/dz:    1.0,       1.0

Note: JAX's reverse-mode automatic differentiation function, jacrev, would have produced the same result.
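A short sketch confirms the note above and shows that composing the two modes (forward-over-reverse) yields the second-order partial derivatives, i.e., the Hessian, which JAX also exposes as jax.hessian. Here the expected Hessian of func2 is [[4, -3, 0], [-3, 0, 0], [0, 0, 0]].

```python
import jax
import jax.numpy as jnp

def func2(x: jax.Array) -> jax.Array:
    return 2.0*x[0]*x[0] - 3.0*x[0]*x[1] + x[2]

y = jnp.array([2.0, -1.0, 6.0])

fwd = jax.jacfwd(func2)(y)   # forward mode
rev = jax.jacrev(func2)(y)   # reverse mode
print(fwd, rev)

# Forward-over-reverse composition: matrix of second-order partial derivatives
hess = jax.jacfwd(jax.jacrev(func2))(y)
print(hess)
```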

Performance: JAX vs NumPy

A significant drawback of the NumPy library is its absence of GPU support. The next objective is to measure the performance gains achieved by JAX, with and without its just-in-time compiler, on both CPU and GPU.

To facilitate this, we define a class named JaxNumpyData containing two functions: np_func, which uses NumPy, and jnp_func, its JAX counterpart. These functions are applied to datasets of various sizes. The compare method extracts 19 subsets from the initial dataset, from 5% to 95% of its size in 5% increments.
from typing import Callable
import numpy as np
import jax
import jax.numpy as jnp


class JaxNumpyData(object):
    """
        Initialize the NumPy and JAX functions to process data (arrays)
        :param np_func NumPy numerical function
        :param jnp_func Corresponding JAX numerical function
    """
    def __init__(self,
                 np_func: Callable[[np.ndarray], np.ndarray],
                 jnp_func: Callable[[jax.Array], jax.Array]):
        self.np_func = np_func
        self.jnp_func = jnp_func
        # Wrap the JAX function once with the just-in-time compiler
        self.jit_func = jax.jit(jnp_func)

    def compare(self, full_data_size: int, func_label: str):
        """
        Compare the execution time of the NumPy, JAX and JAX-JIT functions on subsets of increasing size
        :param full_data_size Size of the original dataset used to extract sub-data sets
        :param func_label Label used for performance results and plotting
        """
        for index in range(1, 20):
            fraction = 0.05 * index
            data_size = int(full_data_size*fraction)

            # Execute on the full_data_size*fraction elements using NumPy
            x_0 = np.linspace(0.0, 100.0, data_size)
            result1 = self.map_numpy(x_0, f'numpy_{func_label}')

            # Execute on the full_data_size*fraction elements using JAX and JAX-JIT
            x_1 = jnp.linspace(0.0, 100.0, data_size)
            result2 = self.map_jax(x_1, f'jax_{func_label}')
            result3 = self.map_jit(x_1, f'jit_{func_label}')

            del x_0, x_1, result1, result2, result3

    @time_it
    def map_numpy(self, np_x: np.ndarray, label: str) -> np.ndarray:
        """ Process NumPy array, np_x through NumPy function np_func """
        return self.np_func(np_x)

    @time_it
    def map_jax(self, jnp_x: jax.Array, label: str) -> jax.Array:
        """ Process JAX array, jnp_x through JAX function jnp_func """
        return self.jnp_func(jnp_x)

    @time_it
    def map_jit(self, jnp_x: jax.Array, label: str) -> jax.Array:
        """ Process JAX array, jnp_x through the JIT-compiled version of jnp_func """
        return self.jit_func(jnp_x)

The method map_numpy (resp. map_jax and map_jit) applies the NumPy function np_func (resp. the JAX function jnp_func, without and with just-in-time compilation) to the NumPy array np_x (resp. the JAX array jnp_x).
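When timing JAX, keep in mind that computations are dispatched asynchronously, so a timer stopped right after a JAX call may only measure the dispatch. The standalone sketch below, independent of JaxNumpyData and using a smaller array of 1,000,000 values, waits for the result with block_until_ready() and triggers the JIT compilation with a warm-up call before timing. The helper elapsed is hypothetical.

```python
import time
import numpy as np
import jax
import jax.numpy as jnp

def np_f(x: np.ndarray) -> np.ndarray:
    return np.sinh(x) + np.cos(x)

def jnp_f(x: jax.Array) -> jax.Array:
    return jnp.sinh(x) + jnp.cos(x)

def elapsed(fn, x) -> float:
    start = time.perf_counter()
    out = fn(x)
    # JAX arrays are computed asynchronously: wait for the result before stopping the clock
    if hasattr(out, 'block_until_ready'):
        out.block_until_ready()
    return time.perf_counter() - start

n = 1_000_000
x_np = np.linspace(0.0, 100.0, n)
x_jnp = jnp.linspace(0.0, 100.0, n)
jit_f = jax.jit(jnp_f)
jit_f(x_jnp).block_until_ready()   # warm-up call compiles jit_f for this shape

print(f'numpy {elapsed(np_f, x_np):.4f}s  jax {elapsed(jnp_f, x_jnp):.4f}s  jit {elapsed(jit_f, x_jnp):.4f}s')
```

Excluding the compilation from the measurement reflects the typical use case where a compiled function is called many times on inputs of the same shape.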

CPU

In this first performance test, we measure the time to compute \[f(x) = \sinh(x)+\cos(x)\] on 1,000,000,000 values using NumPy, and JAX with and without the just-in-time compiler.
def np_func1(x: np.array) -> np.array:
    return np.sinh(x) + np.cos(x)

def jnp_func1(x: jnp.array) -> jnp.array:
    return jnp.sinh(x) + jnp.cos(x)
JAX delivers a 7-fold performance improvement over NumPy. The just-in-time compiler adds another 35% improvement.


The second latency test computes the mean value
\[\bar{x}=\frac{1}{n} \sum_{i=1}^{n} x_{i}\] of a NumPy and a JAX array of 1,200,000,000 values.
def np_func2(x: np.array) -> np.array:
    return np.mean(x)

def jnp_func2(x: jnp.array) -> jnp.array:
    return jnp.mean(x)
The JIT-compiled JAX function outperforms both NumPy and the non-compiled JAX function on CPU.


GPU

For this last test, we execute the function \[f(x)=e^{-x} + \cos(x)\] over 200,000,000 values on an Nvidia V100 GPU.
As anticipated, NumPy runs on the CPU of the EC2 instance and therefore cannot match the performance of JAX running on the Nvidia GPU.

Conclusion

In summary, JAX offers data scientists and machine learning engineers a high-performance GPU computing tool that significantly outperforms the NumPy library. Our exploration has only touched the surface of JAX's capabilities, and I encourage readers to delve deeper into features like Autobatching, Vectorization, Generalized convolutions, and its integration with PyTorch and TensorFlow.

Thank you for reading this article. For more information ...


References


Appendix

We include the decorator used for timing the execution of the various functions, for reference.
import time
from functools import wraps
from typing import Dict, List

timing_stats: Dict[str, List[float]] = {}

def time_it(func):
    """ Decorator for timing execution of methods """

    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        func(*args, **kwargs)
        # Measure the elapsed time once and reuse it for display and statistics
        elapsed = time.time() - start
        # The label is the third positional argument: (self, data, label)
        key: str = args[2]
        print(f'{key}\t{elapsed:.3f} secs.')
        timing_stats.setdefault(key, []).append(elapsed)
        # The computed result is discarded; only the timing is recorded
        return 0
    return wrapper



---------------------------
Patrick Nicolas has over 25 years of experience in software and data engineering, architecture design and end-to-end deployment and support with extensive knowledge in machine learning. 
He has been director of data engineering at Aideo Technologies since 2017 and he is the author of "Scala for Machine Learning" Packt Publishing ISBN 978-1-78712-238-3

Sunday, August 20, 2023

Automate Medical Coding Using BERT

Target audience: Beginner
Estimated reading time: 5'
Transformers and self-attention models are increasingly taking center stage in the NLP toolkit of data scientists [ref 1]. This article delves into the design, deployment, and assessment of a specialized transformer tasked with extracting medical codes from Electronic Health Records (EHR) [ref 2]. The focus is on curbing development and training expenses while ensuring the model remains current.


Table of contents
Introduction
       Extracting medical codes
       Minimizing costs
       Keeping models up-to-date
Architecture
Tokenizer
BERT encoder
       Context embedding
       Segmentation
       Transformer
       Self-attention
Classifier
Active learning
References


Important notes
  • This piece doesn't serve as a primer or detailed account of transformer-based encoders, Bidirectional Encoder Representations from Transformers (BERT), multi-label classification or active learning. Detailed technical information on these models is available in the References section [ref 1, 3, 8, 12].
  • The terms medical document, medical note and clinical note are used interchangeably.
  • Some functionalities discussed here are protected intellectual property, hence the omission of source code.


Introduction

Autonomous medical coding refers to the use of artificial intelligence (AI) and machine learning (ML) technologies to automatically assign medical codes to patient records [ref 4]. Medical coding is the process of assigning standardized codes to diagnoses, medical procedures, and services provided during a patient's visit to a healthcare facility. These codes are used for billing, reimbursement, and research purposes.


By automating the medical coding process, healthcare organizations can improve efficiency, accuracy, and consistency, while also reducing costs associated with manual coding.

 

A health insurance claim is an indication of the service given by a provider, even though the medical records associated with this service can greatly vary in content and structure. It's crucial to precisely extract medical codes from clinical notes since outcomes, like hospitalizations, treatments, or procedures, are directly tied to these diagnostic codes. Even if there are minor variations in the codes, claims can still be valid for specific services, provided the clinical notes, patient history, diagnosis, and advised procedures align.


fig. 1 Extraction of knowledge, predictions from electronic medical records 

Medical coding is the transformation of healthcare diagnosis, procedures, medical services described in electronic health records, physician's notes or laboratory results into alphanumeric codes.  This study focuses on automated generation of medical codes and health insurance claims from a given clinical note or electronic health record.

Challenges

There are 3 issues to address:
  1. How to extract medical codes reliably, given that labeling of medical codes is error prone and the clinical documents are very inconsistent?
  2. How to minimize the cost of self-training complex deep models such as transformers while preserving an acceptable accuracy?
  3. How to continuously keep models up to date in production environment?

Extracting medical codes

Medical codes are derived from patient records and clinical notes to forecast procedural results, determine the length of hospital stays, or generate insurance claims. The most prevalent medical coding systems include:
  • International Classification of Diseases (ICD-10) for diagnosis (with roughly 72,000 codes)
  • Current Procedural Terminology (CPT) for procedures and medications (encompassing around 19,000 codes)
  • Along with others like Modifiers, SNOMED, and so forth.
The vast array of medical codes poses significant challenges in extraction due to:
  • The seemingly endless combinations of codes linked to a specific medical document
  • Varied and inconsistent formats of patient records (in terms of terminology, structure, and length).
  • Complications in gleaning context from medical information systems.

Minimizing costs

A study on deep learning models suggests that training a large language model (LLM) results in the emission of 626,155 pounds of CO2, comparable to the total emissions of five cars over their lifetime.

To illustrate, GPT-3/ChatGPT underwent training on 500 billion words with a model size of 175 billion parameters. A single training session would require 355 GPU-years and bear a cost of no less than $4.6M. Efforts are currently being made to fine-tune resource utilization for the development of upcoming models [ref 5].

Keeping models up-to-date

Customer data in real-time is continuously changing, often deviating from the distribution patterns the models were originally trained on (due to concept and covariate shifts).
This challenge is particularly pronounced for transformers that need task-specific fine-tuning and might even necessitate restarting the pre-training process — both of which are resource-intensive actions.

Architecture

To tackle the challenges highlighted earlier, the proposed solution should encompass four essential AI/NLP elements:
  • Tokenizer to extract tokens, segments & vocabulary from a corpus of medical documents.
  • Bidirectional Encoder Representations from Transformers (BERT) to generate a representation (embedding) of the documents [ref 3].
  • Neural-based classifier to predict a set of diagnostic codes or insurance claim given the embeddings.
  • Active/transfer learning framework to update model through optimized selection/sampling of training data from production environment.
From a software engineering perspective, the system architecture should provide a modular integration capability with current IT infrastructures. It also requires an asynchronous messaging system with streaming capabilities, such as Kafka, and REST API endpoints to facilitate testing and seamless production deployment.

fig. 2  Architecture for integration of AI components with external medical IT systems 


Tokenizer 

The effectiveness of a transformer encoder's output hinges on the quality of its input: tokens and segments or sentences derived from clinical documents. Several pressing questions need addressing:

  1. Which vocabulary is most suitable for token extraction from these notes? Do we consider domain-specific terms, abbreviations, Tf-Idf scores, etc.?
  2. What's the best approach to segmenting a note into coherent units, such as sections or sentences?
  3. How do we incorporate or embed pertinent contextual data about the patient or provider into the encoder?
Tokens play a pivotal role in formulating a dynamic vocabulary. This vocabulary can be enriched by incorporating words or N-grams from various sources like:
  • Terminology from the American Medical Association (AMA)
  • Common medical terms with high TF-IDF scores
  • Different senses of words
  • Abbreviations
  • Semantic descriptions
  • Stems
  • .....

fig. 3 Generation of a vocabulary using training corpus and knowledge base

Our optimal approach is based on utilizing uncased words from the American Medical Association, coupled with the top 85% of terms derived from training medical notes, ranked by their highest TF-IDF scores. It's worth noting that this method can be resource-intensive.
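The vocabulary construction described above can be sketched as follows, under several assumptions: naive whitespace tokenization, a term's rank given by its highest TF-IDF score across notes, and a tiny hypothetical set standing in for the AMA terminology. The function build_vocabulary is illustrative, not the project's implementation.

```python
import math
from collections import Counter
from typing import Dict, List, Set

def build_vocabulary(notes: List[str], domain_terms: Set[str], keep_ratio: float = 0.85) -> Set[str]:
    """Uncased domain terms + top keep_ratio of note terms ranked by their highest TF-IDF score."""
    # Naive uncased whitespace tokenization (a simplifying assumption)
    docs = [note.lower().split() for note in notes]
    n_docs = len(docs)
    # Document frequency: number of notes containing each term
    doc_freq = Counter(term for doc in docs for term in set(doc))

    # Highest TF-IDF score of each term across all notes
    best_score: Dict[str, float] = {}
    for doc in docs:
        for term, count in Counter(doc).items():
            score = (count / len(doc)) * math.log(n_docs / doc_freq[term])
            best_score[term] = max(best_score.get(term, 0.0), score)

    ranked = sorted(best_score, key=best_score.get, reverse=True)
    top_terms = ranked[: int(len(ranked) * keep_ratio)]
    return {term.lower() for term in domain_terms} | set(top_terms)

notes = ["patient reports chest pain",
         "patient denies chest pain fever",
         "patient fracture left wrist"]
vocabulary = build_vocabulary(notes, domain_terms={"Dyspnea"})
print(sorted(vocabulary))
```

A term present in every note (here "patient") has a zero IDF and falls outside the kept fraction.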

BERT encoder

In NLP, words and documents are represented in the form of numeric vectors allowing similar words to have similar vector representations [ref 6].
The objective is to generate embeddings for medical documents, including contextual data, to be fed into a deep learning classifier that extracts diagnostic codes or generates a medical insurance claim [ref 7].

Context embedding 

Contextual information such as patient data (age, gender, ...), medical service provider, specialty, or location is categorized (or bucketed, for continuous values) and added to the tokens extracted from the medical note.
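A minimal sketch of this step, assuming a hypothetical token format such as [age_2] and arbitrary age bucket boundaries (18, 40, 65); the actual encoding used in the project is not published.

```python
from bisect import bisect_right
from typing import Dict, List, Sequence

def context_tokens(context: Dict[str, object],
                   age_edges: Sequence[int] = (18, 40, 65)) -> List[str]:
    """Convert patient/provider attributes into categorical tokens (hypothetical token format)."""
    tokens = []
    for key, value in sorted(context.items()):
        if key == 'age':
            # Bucket the continuous value: index of the interval delimited by age_edges
            tokens.append(f'[age_{bisect_right(age_edges, value)}]')
        else:
            # Categorical values are normalized to lowercase tokens
            tokens.append(f'[{key}_{str(value).lower()}]')
    return tokens

note_tokens = ['patient', 'presents', 'with', 'laceration']
ctx = {'age': 52, 'gender': 'F', 'specialty': 'Orthopedics'}
print(context_tokens(ctx) + note_tokens)
```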

Segmentation

Structuring electronic health records into logical or random groups of segments/sentences presents a significant challenge. Segmentation involves dividing a medical document into segments (or sections), each with an equal number of tokens that consist of sentences and relevant contextual data.

Several methods can be employed to segment a document:
  1. Isolating the contextual data as a standalone segment.
  2. Integrating the contextual data into the document's initial segment.
  3. Embedding the contextual data into any arbitrarily chosen segment [ref 6].

fig. 4 Embedding of medical note with contextual data using 2 segments


Our study shows that option 2 provides the best embedding for the feed-forward neural network classifier.
Interestingly, treating the entire note as a single sentence and using the AMA vocabulary leads to diminished accuracy in subsequent classification tasks.
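Option 2 can be sketched as a BERT-style input: the context tokens lead the first segment, the merged sequence is split into two segments of (nearly) equal length, and segment (token type) ids distinguish them. The function build_input is hypothetical, and [CLS]/[SEP] follow the usual BERT convention.

```python
from typing import List, Tuple

def build_input(context: List[str], note_tokens: List[str]) -> Tuple[List[str], List[int]]:
    """Two-segment BERT input with the contextual data merged into the first segment (option 2)."""
    merged = context + note_tokens      # context leads the first segment
    mid = (len(merged) + 1) // 2        # two segments of (nearly) equal length
    first, second = merged[:mid], merged[mid:]
    tokens = ['[CLS]'] + first + ['[SEP]'] + second + ['[SEP]']
    # Segment ids: 0 for [CLS], the first segment and its [SEP]; 1 for the second segment
    segment_ids = [0] * (len(first) + 2) + [1] * (len(second) + 1)
    return tokens, segment_ids

tokens, segment_ids = build_input(['[age_2]', '[gender_f]'],
                                  ['fall', 'from', 'ladder', 'laceration', 'of', 'abdomen'])
print(list(zip(tokens, segment_ids)))
```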

Transformer

We employ the self-supervised Bidirectional Encoder Representations from Transformers (BERT) model with the objectives to:
  • Grasp the contextual significance of medical phrases.
  • Create embeddings/representations that merge clinical notes with contextual data.
The model construction involves two phases:
  1. Pretraining on an extensive, domain-specific corpus [ref 8].
  2. Fine-tuning tailored for specific tasks, like classification [ref 9].

After the pretraining phase concludes, the document embeddings are fed to the classifier for training. They can be sourced:
  1. Directly from the output of the pretrained model (document embeddings).
  2. During the fine-tuning process of the pretrained model. Concurrently, fine-tuning operates alongside active learning for model updates.


fig. 5 Model weights update with features extraction vs fine tuning
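The difference between the two options boils down to whether gradients reach the encoder weights. In this sketch, a single dense layer stands in for the pretrained BERT encoder (an assumption for brevity), and jax.lax.stop_gradient freezes it in feature-extraction mode.

```python
import jax
import jax.numpy as jnp

def model(params, x, fine_tune: bool):
    # Toy stand-in for the pretrained encoder: a single dense layer
    features = jnp.tanh(x @ params['encoder'])
    if not fine_tune:
        # Feature extraction: block gradients so the encoder weights stay frozen
        features = jax.lax.stop_gradient(features)
    return features @ params['head']          # classifier head

def loss(params, x, y, fine_tune: bool):
    return jnp.mean((model(params, x, fine_tune) - y) ** 2)

key_enc, key_head, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
params = {'encoder': jax.random.normal(key_enc, (4, 8)),
          'head': jax.random.normal(key_head, (8, 1))}
x = jax.random.normal(key_x, (16, 4))
y = jnp.ones((16, 1))

frozen = jax.grad(loss)(params, x, y, False)   # only the head receives gradients
tuned = jax.grad(loss)(params, x, y, True)     # encoder and head receive gradients
print(jnp.abs(frozen['encoder']).max(), jnp.abs(tuned['encoder']).max())
```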

It's strongly advised to utilize one of the pretrained BERT models like ClinicalBERT [ref 10] or GatorTron [ref 11], and then adapt the transformer for classification purposes. However, for this particular project, we initiated BERT's pretraining on a distinct set of clinical notes to gauge the influence of vocabulary and segmentation on prediction accuracy.


Self-attention

Here's a concise overview of the multi-head self-attention model for context:
The foundation of a transformer module is the self-attention block that processes token, position, and type embeddings prior to normalization. Multiple such modules are layered together to construct the encoder. A similar architecture is employed for the decoder.
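The core of this block, multi-head scaled dot-product self-attention, can be sketched in JAX as follows. Dimensions and random weights are arbitrary; residual connections, layer normalization and the feed-forward sub-layer of the encoder block are omitted.

```python
import jax
import jax.numpy as jnp

def multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads: int):
    """x: (seq_len, d_model); all weight matrices: (d_model, d_model)."""
    seq_len, d_model = x.shape
    d_head = d_model // num_heads

    def split_heads(t):
        # (seq_len, d_model) -> (num_heads, seq_len, d_head)
        return t.reshape(seq_len, num_heads, d_head).transpose(1, 0, 2)

    q, k, v = split_heads(x @ w_q), split_heads(x @ w_k), split_heads(x @ w_v)
    # Scaled dot-product scores between every pair of tokens, per head
    scores = q @ k.transpose(0, 2, 1) / jnp.sqrt(d_head)
    weights = jax.nn.softmax(scores, axis=-1)        # each row sums to 1
    heads = weights @ v                               # (num_heads, seq_len, d_head)
    # Concatenate the heads and project back to d_model
    concat = heads.transpose(1, 0, 2).reshape(seq_len, d_model)
    return concat @ w_o, weights

keys = jax.random.split(jax.random.PRNGKey(0), 5)
seq_len, d_model, num_heads = 6, 16, 4
x = jax.random.normal(keys[0], (seq_len, d_model))   # stands for token + position + type embeddings
w_q, w_k, w_v, w_o = (jax.random.normal(key, (d_model, d_model)) / jnp.sqrt(d_model) for key in keys[1:])
out, attn = multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads)
print(out.shape, attn.shape)   # (6, 16) (4, 6, 6)
```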


fig. 6 Schematic for transformer encoder block

Classifier

The classifier is structured as a straightforward feed-forward neural network (fully connected), since a more intricate design might not considerably enhance prediction accuracy. In addition to the standard hyper-parameter optimization, different network configurations were assessed.
The network's structure, including the number and dimensions of hidden layers, doesn't have a significant influence on the overall predictive performance.
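The classifier itself is not published. As an illustration only, here is a minimal fully connected multi-label classifier in JAX, assuming document embeddings of dimension 16, five candidate codes with one sigmoid output per code, binary cross-entropy loss and plain gradient descent; all sizes and data are synthetic.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1

def init_params(key, sizes):
    """Fully connected layers, e.g. sizes = [embedding_dim, hidden_dim, num_codes]."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (n_in, n_out)) / jnp.sqrt(n_in), jnp.zeros(n_out))
            for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])]

def logits(params, x):
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b

def loss(params, x, y):
    # Multi-label: one independent sigmoid per code, numerically stable binary cross-entropy
    z = logits(params, x)
    return jnp.mean(jnp.maximum(z, 0.0) - z * y + jnp.log1p(jnp.exp(-jnp.abs(z))))

@jax.jit
def train_step(params, x, y):
    grads = jax.grad(loss)(params, x, y)
    # Plain gradient descent on every weight and bias
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

key_p, key_x, key_y = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (32, 16))                               # 32 synthetic document embeddings
y = (jax.random.uniform(key_y, (32, 5)) > 0.7).astype(jnp.float32)   # 5 candidate codes per note
params = init_params(key_p, [16, 32, 5])

initial_loss = loss(params, x, y)
for _ in range(100):
    params = train_step(params, x, y)
predicted_codes = jax.nn.sigmoid(logits(params, x)) > 0.5            # multi-label prediction
print(initial_loss, loss(params, x, y))
```

Changing the list passed to init_params is enough to compare network configurations, in line with the observation above.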


Active learning

The goal is to modify models to tackle the issue of covariate shifts observed in the distribution of real-time/production data during inference.

The dual-faceted approach involves:
  1. Selecting data samples with labels that deviate from the distribution initially employed during training (Active learning) [ref 12].
  2. Adjusting the transformer for the classification objective using these samples (Transfer learning).
A significant obstacle in predicting diagnostic codes or medical claims is the steep labeling expense. In this context, learning algorithms can proactively seek labels from domain experts. This iterative form of supervised learning is known as active learning.
Because the learning algorithm selectively picks the examples, the quantity of samples needed to grasp a concept is frequently less than that required in traditional supervised learning. In this aspect, active learning parallels optimal experimental design, a standard approach in data analysis [ref 13].


fig. 7 Simplified data pipeline for active learning.

In our scenario, the active learning algorithm picks an unlabeled medical note, termed note-91, and sends it to a human coder who assigns it the diagnostic code S31.623A. Once a substantial number of notes are newly labeled, the model undergoes retraining. Subsequently, the updated model is rolled out and utilized to forecast diagnostic codes on notes in production.
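The selection criterion used in this project is not specified. As an illustration only, the sketch below swaps in uncertainty sampling, a common active learning strategy: the notes whose predicted code probabilities are closest to 0.5 (highest summed binary entropy) are sent to human coders first. The function select_for_labeling and the probabilities are hypothetical.

```python
import jax.numpy as jnp

def select_for_labeling(probs: jnp.ndarray, budget: int) -> jnp.ndarray:
    """Uncertainty sampling: indices of the notes with the most uncertain code predictions."""
    p = jnp.clip(probs, 1e-7, 1.0 - 1e-7)
    # Sum over codes of the binary entropy of each predicted probability
    entropy = -(p * jnp.log(p) + (1.0 - p) * jnp.log(1.0 - p)).sum(axis=1)
    return jnp.argsort(-entropy)[:budget]

# Predicted probabilities for 4 unlabeled notes over 3 diagnostic codes
probs = jnp.array([[0.99, 0.01, 0.02],
                   [0.50, 0.45, 0.60],
                   [0.90, 0.10, 0.05],
                   [0.55, 0.50, 0.35]])
print(select_for_labeling(probs, budget=2))   # notes 1 and 3 go to the human coder
```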

Thank you for reading this article. For more information ...

References


A formal presentation of this project is available at


Glossary

  • Electronic health record (EHR): An electronic version of a patient's medical history, maintained by the provider over time, which may include all of the key administrative clinical data relevant to that person's care under a particular provider, including demographics, progress notes, problems, medications, vital signs, past medical history, immunizations, laboratory data and radiology reports.
  • Medical document: Any medical artifact related to the health of a patient: clinical note, X-rays, lab analysis results, ...
  • Clinical note: Medical document written by physicians following a visit. This is a textual description of the visit, focusing on vital signs, diagnosis, recommendations and follow-up.
  • ICD (International Classification of Diseases): Diagnostic codes that serve a broad range of uses globally and provide critical knowledge on the extent, causes and consequences of human disease and death worldwide via data that is reported and coded with the ICD. Clinical terms coded with ICD are the main basis for health recording and statistics on disease in primary, secondary and tertiary care, as well as on cause of death certificates.
  • CPT (Current Procedural Terminology):  Codes that offer health care professionals a uniform language for coding medical services and procedures to streamline reporting, increase accuracy and efficiency. CPT codes are also used for administrative management purposes such as claims processing and developing guidelines for medical care review.

