While Python is predominantly recognized as the go-to programming language for data science, leveraging Java-based frameworks can offer substantial advantages, especially for rapid distributed inference.
- This article doesn't delve into the specifics of Apache Spark, Kafka, Deep Java Library, or BERT individually. Instead, it focuses on how these components are integrated to create an efficient solution for inference tasks.
- Development environment: JDK 11, Scala 2.12.15, Apache Spark 3.3.1, Apache Kafka 2.8.0, Deep Java Library 0.20.0.
- Comments and ancillary code are omitted for the sake of clarity.
- Source code available at https://github.com/patnicolas/bertspark
Combining the best of both worlds
- Python has a limited capacity for task parallelization, whether through concurrent threads or by distributing tasks across a network.
- Commercial applications often depend on web services running on the Java Virtual Machine (JVM) and make extensive use of Apache's open-source libraries.
Java/Scala for inference
Typically, the process involves creating models in a Python environment such as Jupyter, an IDE, or Anaconda, and then saving the model parameters. The Deep Java Library (DJL) then takes over by loading these saved parameters and initializing the inference model, which is then ready to handle runtime requests.
Distributed inference pipeline
By integrating Apache Spark and DJL, we can enhance the scalability of predictions by parallelizing the execution of deep learning models. The critical components of this distributed inference pipeline include:
- Apache Spark: This tool segments runtime requests for predictions into batches. These batches are then processed concurrently across remote worker nodes.
- Apache Kafka: This acts as an asynchronous messaging queue, effectively separating the client application from the inference pipeline, ensuring smooth data flow without bottlenecks.
- Deep Java Library (DJL): It connects with the binary executables of the deep learning models.
- Kubernetes: This system orchestrates the containerized instances of the inference pipelines, facilitating scalable and automated deployment. Notably, Spark has offered native Kubernetes integration since version 2.3, and it became generally available in version 3.1.
- Deep Learning Frameworks: This includes well-known frameworks like TensorFlow, MXNet, and PyTorch, which are part of the overall architecture.
The two main benefits of such a pipeline are simplicity (all tasks and processes run on the JVM) and low latency.
Note: Spark and DJL can also be used in the training phase to distribute the training of a mini-batch.
Apache Kafka
First, we construct the handler class, KafkaPrediction, which
- consumes requests from the Kafka topic consumeTopic
- invokes the prediction model and transformation, predictionPipeline
- produces predictions to the Kafka topic produceTopic
class KafkaPrediction(
  consumeTopic: String,
  produceTopic: String,
  predictionPipeline: Seq[Request] => Seq[Prediction]) {
  // 1 - Construct the transform of Kafka messages for prediction
  val transform = (requestMsg: Seq[RequestMessage]) => {
    // 2 - Invoke the execution of the pipeline
    val predictions = predictionPipeline(requestMsg.map(_.requestPayload))
    predictions.map(ResponseMessage(_))
  }
  // 3 - Build the Kafka consumer for prediction requests
  val consumer = new KafkaConsumer[RequestMessage](
    RequestSerDe.deserializingClass,
    consumeTopic
  )
  // 4 - Build the Kafka producer for prediction responses
  val producer = new KafkaProducer[ResponseMessage](
    ResponseSerDe.serializingClass,
    produceTopic
  )
  .....
}
- We first create a wrapper function, transform, to generate predictions. It converts request messages of type RequestMessage into response messages of type ResponseMessage.
- The wrapper extracts the actual requests (Request) from the messages of type RequestMessage consumed from Kafka, then invokes the prediction pipeline, predictionPipeline. The predictions are wrapped into messages of type ResponseMessage to be produced to Kafka.
- The consumer is fully defined by the deserializer for the data consumed from Kafka and by its associated topic.
- The producer serializes the responses and sends them back to the Kafka service.
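The wrapping and unwrapping performed by transform can be tried out without a Kafka cluster. The following is a minimal sketch, not the author's implementation: the fields of Request, Prediction, RequestMessage and ResponseMessage are assumptions made for illustration, and dummyPipeline is a hypothetical stand-in for the real model.

```scala
object KafkaTransformSketch {
  // Hypothetical payload and message types (field names are assumptions)
  final case class Request(id: String, text: String)
  final case class Prediction(id: String, label: String)
  final case class RequestMessage(timestamp: Long, requestPayload: Request)
  final case class ResponseMessage(payload: Prediction)

  // Lifts a payload-level pipeline into a message-level transform
  def buildTransform(
    predictionPipeline: Seq[Request] => Seq[Prediction]
  ): Seq[RequestMessage] => Seq[ResponseMessage] =
    (requestMsgs: Seq[RequestMessage]) => {
      // Unwrap the requests, predict, then wrap the predictions
      val predictions = predictionPipeline(requestMsgs.map(_.requestPayload))
      predictions.map(ResponseMessage(_))
    }

  // Stand-in pipeline: labels a request by text length instead of calling a model
  val dummyPipeline: Seq[Request] => Seq[Prediction] =
    _.map(r => Prediction(r.id, if (r.text.length > 20) "long" else "short"))
}
```

Keeping the pipeline as a plain function of type Seq[Request] => Seq[Prediction] means the messaging layer can be tested independently of the model.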
def executeBatch(
  consumeTopic: String,
  produceTopic: String,
  maxNumResponses: Int): Unit = {
  // 1 - Initialize the prediction pipeline
  val kafkaHandler = new KafkaPrediction(
    consumeTopic,
    produceTopic,
    predictionPipeline
  )
  while(running) {
    // 2 - Poll the request topic (has its own specific Kafka exception handler)
    val consumerRecords = kafkaHandler.consumer.receive
    if(consumerRecords.nonEmpty) {
      // 3 - Generate and apply the transform to the batch
      val input: Seq[RequestMessage] = consumerRecords.map(_._2)
      val responses = kafkaHandler.predict(input)
      if(responses.nonEmpty) {
        // 4 - Key each response message by the identifier of its payload
        val respMessages = responses.map(
          response => (response.payload.id, response)
        )
        // 5 - Produce the batch of response messages to Kafka
        kafkaHandler.producer.send(respMessages)
        // 6 - Asynchronously confirm with Kafka that the batch has been processed
        kafkaHandler.consumer.asyncCommit
      }
      else
        logger.error("No response is produced to Kafka")
    }
  }
  // 7 - Release the Kafka resources once the loop exits
  kafkaHandler.close
}
- First, we instantiate the Kafka message handler class, KafkaPrediction, that we created earlier.
- At regular intervals, we poll Kafka for a batch of new requests.
- If the batch is not empty, we invoke the handler method, predict, which calls the prediction models.
- Once done, we encapsulate the predictions into ResponseMessage instances.
- The messages are produced to the producer topic in the Kafka queue.
- Finally, Kafka asynchronously acknowledges that the responses were received correctly.
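The poll, predict and produce cycle described above can be sketched with in-memory queues standing in for the two Kafka topics. This is an illustration under stated assumptions, not the Kafka client API: the queues, the batchSize parameter and the Request and Prediction fields are all hypothetical.

```scala
import scala.collection.mutable

object InMemoryPipelineSketch {
  // Hypothetical payload types (field names are assumptions)
  final case class Request(id: String, text: String)
  final case class Prediction(id: String, label: String)

  // Drains the request queue in batches of at most batchSize, applies the
  // pipeline to each batch and appends (id, prediction) pairs to the
  // response queue. Returns the number of responses produced.
  def executeBatches(
    requests: mutable.Queue[Request],
    responses: mutable.Queue[(String, Prediction)],
    batchSize: Int,
    pipeline: Seq[Request] => Seq[Prediction]
  ): Int = {
    require(batchSize > 0, "batchSize must be positive")
    var produced = 0
    while (requests.nonEmpty) {
      // "Poll": take up to batchSize pending requests, in arrival order
      val batch = Seq.fill(math.min(batchSize, requests.size))(requests.dequeue())
      val predictions = pipeline(batch)
      // "Produce": key each prediction by its identifier
      predictions.foreach(p => responses.enqueue((p.id, p)))
      produced += predictions.size
    }
    produced
  }
}
```

Unlike the real handler, this loop stops when the queue is empty rather than waiting for new messages, which keeps the sketch easy to run in a test.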
Apache Spark
Leveraging Spark data sets and partitioning is surprisingly simple.
def predict(
  requests: Seq[Request]
)(implicit sparkSession: SparkSession): Seq[Prediction] = {
  import sparkSession.implicits._
  // 1 - Convert the requests into a Spark data set
  val requestDataset = requests.toDS()
  // 2 - Execute the prediction by invoking the DJL model on each request
  //     (predict(_) is the single-request prediction of the classifier)
  val responseDataset: Dataset[Prediction] = requestDataset.map(predict(_))
  // 3 - Collect the predictions from the Spark data set
  responseDataset.collect()
}
- Once the Spark session (context) is initiated, the batch of requests is converted into a data set, requestDataset.
- Spark applies the prediction model (DJL) to each request in the partitioned data.
- Finally, the predictions are collected from the Spark worker nodes before being returned to the Kafka handler.
Note: The Spark context is assumed to be created and passed as an implicit parameter to the prediction method.
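The partition, map and collect pattern behind this method can be illustrated with standard Scala futures. The sketch below is a local stand-in and does not use the Spark API: numPartitions, predictOne and the 30-second timeout are assumptions chosen for the example.

```scala
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._

object PartitionedPredictSketch {
  // Splits the requests into at most numPartitions chunks, applies predictOne
  // to each chunk concurrently, then gathers the results in the original order.
  def predict[A, B](requests: Seq[A], numPartitions: Int)(predictOne: A => B)(
    implicit ec: ExecutionContext
  ): Seq[B] = {
    require(numPartitions > 0, "numPartitions must be positive")
    if (requests.isEmpty) Seq.empty
    else {
      val partitionSize = math.ceil(requests.size.toDouble / numPartitions).toInt
      val partitions = requests.grouped(partitionSize).toList
      // One future per partition plays the role of one task per worker
      val futures = partitions.map(partition => Future(partition.map(predictOne)))
      // "Collect": wait for all partitions and flatten them back into one sequence
      Await.result(Future.sequence(futures), 30.seconds).flatten
    }
  }
}
```

As with the implicit Spark session in the method above, the execution context is passed implicitly, so the caller decides where the work runs.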
Deep Java Library
DJL's capability to adapt to any hardware setup (be it CPU or GPU) and its integration with big data frameworks position it as an ideal choice for a high-performance distributed inference engine [ref 4]. The library is particularly well-suited for constructing transformer encoders such as BERT, as well as decoders such as GPT.
In this setup, the input tensors are processed by the deep learning models on a GPU. Importantly, the data is allocated in the native memory space, which is external to the JVM and its garbage collector. The DJL library supports native tensor types such as NDArray and lists of tensors like NDList, along with a straightforward memory management tool, NDManager.
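Because this memory is invisible to the garbage collector, it has to be released explicitly. The idea of a manager that owns every buffer it allocates can be sketched in plain Scala. The classes below are hypothetical stand-ins written for illustration and are not part of DJL.

```scala
import scala.collection.mutable.ArrayBuffer

object ScopedManagerSketch {
  // Hypothetical stand-in for a tensor allocated outside the JVM heap
  final class NativeBuffer(val size: Int) extends AutoCloseable {
    private var released = false
    def isReleased: Boolean = released
    override def close(): Unit = { released = true }
  }

  // Hypothetical stand-in for a manager that owns the buffers it allocates
  final class BufferManager extends AutoCloseable {
    private val owned = ArrayBuffer.empty[NativeBuffer]
    def allocate(size: Int): NativeBuffer = {
      val buffer = new NativeBuffer(size)
      owned += buffer
      buffer
    }
    // Closing the manager releases every buffer attached to it
    override def close(): Unit = {
      owned.foreach(_.close())
      owned.clear()
    }
  }

  // Runs a computation with a fresh manager and always releases it afterwards
  def withManager[T](body: BufferManager => T): T = {
    val manager = new BufferManager
    try body(manager) finally manager.close()
  }
}
```

Scoping allocations to a manager in this way avoids tracking each tensor individually, which matters on long-running worker nodes.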
The classifier operates on the Spark worker node. The following code snippet, though a simplified version, illustrates the steps involved in invoking a BERT-based classifier using the DJL framework.
import java.nio.file.Path
import ai.djl.Application
import ai.djl.engine.Engine
import ai.djl.ndarray.{NDList, NDManager}
import ai.djl.repository.zoo.Criteria
import ai.djl.training.util.ProgressBar
import org.apache.spark.sql.SparkSession

class BERTClassifier(
  minTermFrequency: Int,
  path: Path)(implicit sparkSession: SparkSession) {
  // 1 - Manage tensor allocation as NDArray
  val ndManager = NDManager.newBaseManager()
  // 2 - Define the configuration of the classifier
  //     (options and classificationBlock are assumed to be defined elsewhere)
  val classifyCriteria: Criteria[NDList, NDList] = Criteria.builder()
    .optApplication(Application.UNDEFINED)
    .setTypes(classOf[NDList], classOf[NDList])
    .optOptions(options)
    .optModelUrls(s"file://${path.toAbsolutePath}")
    .optBlock(classificationBlock)
    .optEngine(Engine.getDefaultEngineName())
    .optProgress(new ProgressBar())
    .build()
  // 3 - Load the model from a local file
  val thisModel = classifyCriteria.loadModel()
  // 4 - Instantiate a new predictor
  val predictor = thisModel.newPredictor()
  // 5 - Execute this request on this worker node
  //     (the conversion between Request/Prediction and NDList is omitted in this simplified version)
  def predict(requests: Request): Prediction = {
    predictor.predict(ndManager, requests)
  }
  // 6 - Close resources: the predictor first, then the model and the manager
  def close(): Unit = {
    predictor.close()
    thisModel.close()
    ndManager.close()
  }
}
- Set the manager for tensor in native memory
- Configure the classifier with its related neural block (classificationBlock)
- Load the model (MXNet, PyTorch or TensorFlow) from local file
- Instantiate a predictor from the model
- Submit the request to the DL model and return a prediction
- Close all the resources allocated in the native memory at the end of the run
Use case: BERT
Architecture
- Text processor (Tokenizer, Document segmentation,...)
- Pre-trained BERT
- Fully-connected neural network classifier (supervised)
A transformer model consists of two main components: an encoder and a decoder. The encoder's role is to convert sentences and paragraphs into an internal format, typically a numerical matrix, that captures the context of the input. Conversely, the decoder interprets and reverses this process. When combined, the encoder and decoder enable the transformer to execute sequence-to-sequence tasks like translation. Interestingly, isolating the encoder part of the transformer provides insights into the context, enabling various intriguing applications.
BERT has been applied to various problems, including the automation of medical coding [ref 5].
Neural blocks
- Transformer, self-attention block with token, position and sentence order embeddings
- Masked Language Model (MLM) block
- Next Sentence Prediction (NSP) block
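To make the role of the Masked Language Model block more concrete, the sketch below shows how an input sequence and its prediction targets can be derived from a list of masked positions. It is an illustration only: the mask function, its signature and the example tokens are assumptions, and "[MASK]" is the placeholder token used by BERT.

```scala
object MlmMaskingSketch {
  val MaskToken = "[MASK]"

  // Replaces the tokens at maskedIndices with the mask token and returns
  // (masked sequence, original tokens at the masked positions).
  // The second element is what the MLM block is trained to predict.
  def mask(tokens: Seq[String], maskedIndices: Seq[Int]): (Seq[String], Seq[String]) = {
    require(
      maskedIndices.forall(i => i >= 0 && i < tokens.size),
      "masked indices must be within the token sequence"
    )
    val indexSet = maskedIndices.toSet
    val masked = tokens.zipWithIndex.map {
      case (token, index) => if (indexSet.contains(index)) MaskToken else token
    }
    val labels = maskedIndices.map(tokens(_))
    (masked, labels)
  }
}
```

For instance, masking positions 1 and 4 of the tokens "the", "patient", "has", "a", "fever" hides "patient" and "fever", which become the targets of the prediction.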
class CustomPretrainingBlock (
  bertModelType: String,
  activationType: String,
  vocabularySize: Long) extends BaseNetBlock {
  // First block: BERT transformer
  val bertBlock = getBertConfig(bertModelType)
    .setTokenDictionarySize(Math.toIntExact(vocabularySize))
    .build
  val activationFunc: java.util.function.Function[NDArray, NDArray] =
    ActivationConfig.getNDActivationFunc(activationType)
  // Second block: Masked Language Model
  val bertMLMBlock = new BertMaskedLanguageModelBlock(bertBlock, activationFunc)
  // Third block: Next Sentence Predictor
  val bertNSPBlock = new BertNextSentenceBlock
  val pretrainingBlocks = new BERTPretrainingBlocks(
    ("transformer", bertBlock),
    ("mlm", bertMLMBlock),
    ("nsp", bertNSPBlock)
  )
  // The body of this method is listed in the Appendix
  override protected def forwardInternal(
    parameterStore: ParameterStore,
    inputNDList: NDList,
    training : Boolean,
    params: PairList[String, java.lang.Object]): NDList = ???

  def getBertConfig(bertModelType: String): BertBlock.Builder = bertModelType match {
    case `nanoBertLbl` =>
      // 4 encoders, 4 attention heads, embedding size: 256, dimension 256x4
      BertBlock.builder().nano()
    case `microBertLbl` =>
      // 12 encoders, 8 attention heads, embedding size: 512, dimension 512x4
      BertBlock.builder().micro()
    case `baseBertLbl` =>
      // 12 encoders, 12 attention heads, embedding size: 768, dimension 768x4
      BertBlock.builder().base()
    case `largeBertLbl` =>
      // 24 encoders, 16 attention heads, embedding size: 1024, dimension 1024x4
      BertBlock.builder().large()
    case _ =>
      // The original leaves this case empty; failing fast is an assumption
      throw new UnsupportedOperationException(s"BERT model type $bertModelType is not supported")
  }
}
References
[1] Bidirectional Encoder Representations from Transformers
[3] Apache Spark
Appendix
override protected def forwardInternal(
  parameterStore: ParameterStore,
  inputNDList: NDList,
  training : Boolean,
  params: PairList[String, java.lang.Object]): NDList = {
  // Dimension batch_size x max_sentence_size
  val tokenIds = inputNDList.get(0)
  val typeIds = inputNDList.get(1)
  val inputMasks = inputNDList.get(2)
  // Dimension batch_size x num_masked_token
  val maskedIndices = inputNDList.get(3)
  try {
    val ndChildManager = NDManager.subManagerOf(tokenIds)
    ndChildManager.tempAttachAll(inputNDList)
    // Step 1: Process the transformer block for BERT
    // (transformerBlock denotes the BERT transformer block of the pre-training block)
    val bertBlockNDInput = new NDList(tokenIds, typeIds, inputMasks)
    val ndBertResult = transformerBlock.forward(parameterStore, bertBlockNDInput, training)
    // Step 2: Process the Next Sentence Predictor block
    // Embedding sequence dimensions are batch_size x max_sentence_size x embedding_size
    val embeddedSequence = ndBertResult.get(0)
    val pooledOutput = ndBertResult.get(1)
    // Need to un-squeeze for batch size = 1, (embedding_vector) => (1, embedding_vector)
    val unSqueezePooledOutput =
      if(pooledOutput.getShape.dimension() == 1) {
        val expanded = pooledOutput.expandDims(0)
        ndChildManager.tempAttachAll(expanded)
        expanded
      }
      else
        pooledOutput
    // We compute the NSP probabilities in case there is more than one sentence
    val logNSPProbabilities: NDArray =
      bertNSPBlock.forward(parameterStore, new NDList(unSqueezePooledOutput), training)
        .singletonOrThrow
    // Step 3: Process the Masked Language Model block
    // Embedding table dimensions are vocabulary_size x embedding_size
    val embeddingTable = transformerBlock
      .getTokenEmbedding
      .getValue(parameterStore, embeddedSequence.getDevice, training)
    // Dimension: (batch_size x maskSize) x vocabulary_size
    val logMLMProbabilities: NDArray = bertMLMBlock
      .forward(
        parameterStore,
        new NDList(embeddedSequence, maskedIndices, embeddingTable),
        training)
      .singletonOrThrow
    // Finally build the output
    val ndOutput = new NDList(logNSPProbabilities, logMLMProbabilities)
    ndChildManager.ret(ndOutput)
  }
  catch { ... }
}
The author has been director of data engineering at Aideo Technologies since 2017 and wrote "Scala for Machine Learning" (Packt Publishing, ISBN 978-1-78712-238-3).