Thursday, June 15, 2023

Automate Technical Documentation using LLM

Target audience: Beginner
Estimated reading time: 3'

It can be challenging for a data engineer or data scientist to produce, update and maintain the documentation for a project. This article presents the idea of "latent technical documentation" which utilizes tags on software development items (or artifacts), combined with a large language model (LLM), to develop, refine, and maintain a project's documentation.



Notes
  • This post describes a simple use of large language models. It is not an in-depth description of, or even an introduction to, LLMs.
  • ChatGPT API is introduced in two of my previous posts: Secure ChatGPT API client in Scala and ChatGPT API Python client.
  • To enhance the readability of the algorithm implementations, we have omitted non-essential code elements like error checking, comments, exceptions, validation of class and method arguments, scoping qualifiers, and import statements.

Overview

Challenges

The key challenges in documenting an AI-based project to be deployed in production are:
  • Information spread across multiple platforms and devices.
  • Uneven documentation from contributors with different backgrounds: data engineers, DevOps, scientists, product managers, ...
  • Sections of the documentation becoming out-of-date following recent changes in product or service requirements.
  • Missing justification for design or modeling decisions.

What is latent documentation?

Latent technical documentation is a two-step process:
  1. Tagger: Insert comments or document fragments related to a project for each artifact, item or step in the development process (coding, architecture design, unit testing, version control commits, deployment scripts, configurations, container definitions and orchestration, ...).
  2. Generator: Gather and consolidate the various documentation fragments into a single pre-formatted document for the entire project. A large language model (LLM) is an adequate tool to generate clear and concise documentation.
The following diagram illustrates the tagging-generation process.

Illustration of the two-step latent documentation process


Tagging artifacts is accomplished by defining an easy-to-use format that does not add overhead to the development cycle.
In this post we use the following format to select and tag relevant information in an artifact:
    ^#tag_key comments^

Example: ^#KafkaPipelineStream Initiate the processing of streams given properties assigned to the Kafka streams, in config/kafka.conf^

The second step, the generation of the project document, consists of collecting and parsing tags across the various artifacts, then generating a summary and/or formal document for the project. 
Let's review some of the artifact tags.

Tagging artifacts

Engineers and data scientists tag (key & comment) as many artifacts as possible used in the development and, in the case of AI, in the training and validation of models. A partial list of artifacts:
  • Source code files
  • Deployment scripts
  • Version control comments and logs
  • Unit tests objectives
  • Test results
  • Orchestration libraries such as Airflow
  • Container-based frameworks such as Docker, Kubernetes or Terraform
  • Product requirement documents (PRD, MRD)
  • Minutes of meetings.

Python code

Documentation can be extracted from Python source code by selecting and tagging sections of the comments. This process does not add much overhead to the development cycle as conscientious developers document their code anyway.

async def post(self) -> list:
    """
       ^#AsyncHTTP Process the list of iterators (1 iterator per client). 
       The steps are:
        1. Create a task from each coroutine
        2. Aggregate the various tasks
        3. Block on the completion of all the tasks^
        :return: List of results
    """
    tasks = self.__create_tasks()      
    all_tasks = asyncio.gather(*tasks) 
    responses = await all_tasks       

    assert self.num_requests == len(responses), \
            f'Number of responses {len(responses)} != number of requests {self.num_requests}'
    
    return responses


In this code snippet, the documentation fragment "Process the list ... of all the tasks" will be associated with the key AsyncHTTP.

Scala code

The following code snippet defines a tag with the key KafkaPipelineStream for the class constructor PipelineStreams and the method start.

/**
 * ^#KafkaPipelineStream Parameterized basic pipeline streams that consume requests
    using Kafka streams. The topology is created from the request and response topics.
    This class inherits from PipelineStreams.^
 * @param valueDeserializerClass Class or type used in the deserialization for Kafka consumer
 * @tparam T Type of Kafka message consumed
 * @see org.streamingeval.kafka.streams.PipelineStreams
 */
private[kafka] abstract class PipelineStreams[T](valueDeserializerClass: String) {
  protected[this] val properties: Properties = getProperties
  protected[this] val streamBuilder: StreamsBuilder = new StreamsBuilder

  /**
   * ^#KafkaPipelineStream Initiate the processing of streams given properties assigned
      to the Kafka streams, in config/kafka.conf^
   * @param requestTopic Input topic for request (Prediction or Feedback)
   * @param responseTopic Output topic for response
   */
  def start(requestTopic: String, responseTopic: String): Unit =
    for {
      topology <- createTopology(requestTopic, responseTopic)
    } 
    yield {
      val streams = new KafkaStreams(topology, properties)
      streams.cleanUp()
      streams.start()

      logger.info(s"Streaming for $requestTopic requests started!")
      val delayMs = 2000L
      delay(delayMs)

      // Shut down the streaming
      sys.ShutdownHookThread {
          streams.close(Duration.ofSeconds(12))
      }
    }
}

GitHub commits

Documentation can be augmented by tagging the comment(s) of a version control commit. The following command line adds a comment for the key KafkaPipelineStream to a commit.

git commit -m "^#KafkaPipelineStream Implementation of streams using RequestSerDe and ResponseSerDe serialization-deserialization pairs^ for parameterized requests and responses" .

Airflow DAG & tasks

Here is an example of tagging a section of comments in an Airflow Directed Acyclic Graph (DAG) of executable tasks, with the same key KafkaPipelineStream.

default_args = {
    'owner': 'herold',
    'retries': 3,
    'retry_delay': timedelta(minutes=10)
}


@dag(dag_id='produce_note_from_s3',
     default_args=default_args,
     start_date=datetime(2023, 4, 12),
     schedule_interval='@hourly')
def collect_from_s3_etl():
    """
        ^#KafkaPipelineStream Definition of the DAG to load unstructured medical 
        documents from AWS S3. It relies on the loader function, 
        s3_loader defined in module kafka.util.^
    """
    @task()
    def load_from_s3():
        return s3_loader()

    produced_notes = ProduceToTopicOperator(
        task_id="loaded_from_s3",
        kafka_config_id="kafka_default",
        topic=KAFKA_TOPIC,
        producer_function=loader_notes,
        producer_function_args=["{{ ti.xcom_pull(task_ids='load_from_s3') }}"],
        poll_timeout=10,
    )

    # The producer task depends on the S3 loading task
    load_from_s3() >> produced_notes


# Instantiate the DAG
collect_from_s3_etl()

Docker compose

Comments and tags can also be added to containerized application frameworks such as Docker or a container orchestrator like Kubernetes. 
The following multi-container descriptor, docker-compose.yml, uses the KafkaPipelineStream tag to add information regarding the application deployment configuration to the project documentation.

version: '3'
networks:
    datapipeline:
        driver: bridge

services:
    zookeeper:
        # .... image and environment

        # ^#KafkaPipelineStream Kafka docker image loaded from bitnami following zookeeper deployment
        # Port 29092
        # Consumer properties
        # KAFKA_CONSUMER_CONFIGURATION_POOL_TIME_INTERVAL: 14800
        # KAFKA_CONSUMER_CONFIGURATION_MAX_POLL_RECORDS: 120
        # KAFKA_CONSUMER_CONFIGURATION_FETCH_MAX_BYTES: 5428800
        # KAFKA_CONSUMER_CONFIGURATION_MAX_PARTITION_FETCH_BYTES: 1048576^
    kafka:
        image: bitnami/kafka:latest
        container_name: "Kafka"
        restart: always
        depends_on:
            - zookeeper
        ports:
            - 29092:29092
        environment:
            KAFKA_BROKER_ID: 1
            KAFKA_ZOOKEEPER_CONNECT: zookeeper:2181
            KAFKA_ADVERTISED_LISTENERS: PLAINTEXT://kafka:9092, PLAINTEXT_HOST://localhost:29092
            KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: PLAINTEXT:PLAINTEXT,PLAINTEXT_HOST:PLAINTEXT
            KAFKA_INTER_BROKER_LISTENER_NAME: PLAINTEXT
            KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR: 1
            KAFKA_CONSUMER_CONFIGURATION_POOL_TIME_INTERVAL: 14800
            KAFKA_CONSUMER_CONFIGURATION_MAX_POLL_RECORDS: 120
            KAFKA_CONSUMER_CONFIGURATION_FETCH_MAX_BYTES: 5428800
            KAFKA_CONSUMER_CONFIGURATION_MAX_PARTITION_FETCH_BYTES: 1048576
        volumes:
            - ./producer:/producer
            - ./consumer:/consumer
        networks:
            - datapipeline



Generating documentation

The next challenge is to collect the tags and generate the documentation. The generation of the overall project document consists of:
  1. Collecting artifacts using a script
  2. Extracting tags as key-value pairs
  3. Grouping the various documentation fragments per key, as sketched below
  4. Optionally formatting and forwarding the document to an LLM.
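
The following is a minimal sketch of steps 1 to 3, assuming the ^#tag_key comments^ format defined earlier; the directory layout, file extensions and function names are illustrative, not part of an existing toolchain.

import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List

# Tag format defined earlier: ^#tag_key comments^
TAG_PATTERN = re.compile(r'\^#(\w+)\s+(.+?)\^', re.DOTALL)


def collect_fragments(root_dir: str, extensions: List[str]) -> Dict[str, List[str]]:
    """ Scan the artifacts under root_dir and group documentation fragments by tag key """
    fragments: Dict[str, List[str]] = defaultdict(list)
    for ext in extensions:
        for artifact in Path(root_dir).rglob(f'*{ext}'):
            content = artifact.read_text(errors='ignore')
            for key, comment in TAG_PATTERN.findall(content):
                # Collapse whitespace so multi-line comments read as a single fragment
                fragments[key].append(' '.join(comment.split()))
    return fragments


# Collect tags from source code, scripts and deployment descriptors
doc_fragments = collect_fragments('.', ['.py', '.scala', '.sh', '.yml', '.conf'])

The resulting dictionary, keyed by tag, feeds the formatting and LLM steps described next.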

Large language models

Let's use generative AI to create a formal, final project document. The process aggregates the various tag comments into a single text which is used as the context prompt (the 'system' role in ChatGPT).
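
Here is a minimal sketch of this generation step, assuming the openai Python package (pre-1.0, mid-2023 API) and access to a gpt-4 model; the prompt wording and formatting instructions are illustrative.

import openai

# Assumes openai.api_key is configured (see the ChatGPT client posts referenced below)

def generate_documentation(doc_fragments: dict) -> str:
    """ Aggregate the tagged fragments into a context prompt and request a formal document """
    # Aggregate the fragments, grouped by key, into a single context ('system' role)
    context = '\n'.join(
        f'{key}: {" ".join(comments)}' for key, comments in doc_fragments.items()
    )
    response = openai.ChatCompletion.create(
        model='gpt-4',
        messages=[
            {'role': 'system', 'content': context},
            {'role': 'user',
             'content': 'Generate a clear, concise and well-structured technical document '
                        'for this project from the tagged fragments above.'}
        ],
        temperature=0.0
    )
    return response.choices[0].message.content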

The following document was produced by ChatGPT 4.0 [ref 1], although alternative large language models could also be used.
Please refer to the implementation of ChatGPT clients in Scala, Secure ChatGPT API client in Scala, and in Python, ChatGPT API Python client.

1. Overview of PipelineStreams Class:
- The class inherits from PipelineStreams.
- It sets up parameterized basic pipeline streams that consume requests using Kafka stream.
- The topology is derived from both the request and response topics.


2. Kafka Streams Configuration:
- The processing of streams is initiated based on properties assigned to the Kafka streams.
- These properties are located in config/kafka.conf.

3. Implementation Details:
- Streams are implemented using RequestSerDe and ResponseSerDe serialization-deserialization pairs.

4. Loading Medical Documents from AWS S3:
- A Directed Acyclic Graph (DAG) is defined to load unstructured medical documents.
- The loading relies on the s3_loader function.
- The s3_loader function is defined in the kafka.util module.

5. Kafka Docker Deployment:
- Kafka docker image is sourced from bitnami.
- It follows a zookeeper deployment.
- The default port is 29092.

6. Kafka Consumer Properties:
- KAFKA_CONSUMER_CONFIGURATION_POOL_TIME_INTERVAL: 14800
- KAFKA_CONSUMER_CONFIGURATION_MAX_POLL_RECORDS: 120
- KAFKA_CONSUMER_CONFIGURATION_FETCH_MAX_BYTES: 5428800
- KAFKA_CONSUMER_CONFIGURATION_MAX_PARTITION_FETCH_BYTES: 1048576


The quality of the document produced by the LLM depends on the quality of the prompt you provide, so the tags used as context in the LLM request must be crafted meticulously.

The maximum token limit for an LLM prompt (ChatGPT 4.0: 8,192; Llama 2: 4,096; Code Llama: 16,384) can constrain the quantity of tagged information used to create the project document.
Using Retrieval-Augmented Generation (RAG), you can work around the token restriction by encoding the various tag inputs as embedding vectors and storing them in a vector database.

Retrieval-Augmented Generation (RAG)

Retrieval-augmented generation is a more sophisticated way to leverage large language models (LLMs). It is a machine learning framework that relies on an external knowledge base to improve the accuracy of LLMs [ref 2]. The knowledge base contains up-to-date, domain-specific information.

In our case, the knowledge base is built by defining questions (tags) and expected output documentation.
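
A minimal sketch of this approach is shown below, assuming the OpenAI embeddings endpoint and a naive in-memory store based on cosine similarity; a production setup would store the vectors in an actual vector database. The doc_fragments dictionary is the output of the collection script shown earlier.

import numpy as np
import openai
from typing import List

EMBEDDING_MODEL = 'text-embedding-ada-002'   # Assumed embedding model


def embed(text: str) -> np.ndarray:
    """ Encode a text fragment into an embedding vector """
    response = openai.Embedding.create(input=[text], model=EMBEDDING_MODEL)
    return np.array(response['data'][0]['embedding'])


# Index the tagged fragments; a production setup would store the vectors in a vector database
indexed_fragments = [(fragment, embed(fragment))
                     for fragments in doc_fragments.values()
                     for fragment in fragments]


def retrieve(query: str, top_k: int = 8) -> List[str]:
    """ Retrieve the top_k fragments most relevant to the query (cosine similarity) """
    query_vec = embed(query)
    scored = [(np.dot(vec, query_vec) / (np.linalg.norm(vec) * np.linalg.norm(query_vec)), text)
              for text, vec in indexed_fragments]
    return [text for _, text in sorted(scored, reverse=True)[:top_k]]


# Only the retrieved fragments are passed as context to the LLM,
# keeping the prompt within the model's token limit
relevant_fragments = retrieve('Kafka streams deployment and configuration')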

Thank you for reading this article. For more information ...

References



---------------------------
Patrick Nicolas has over 25 years of experience in software and data engineering, architecture design and end-to-end deployment and support with extensive knowledge in machine learning. 
He has been director of data engineering at Aideo Technologies since 2017 and he is the author of "Scala for Machine Learning" Packt Publishing ISBN 978-1-78712-238-3

Wednesday, May 31, 2023

Pythonic Coding Exercises: Recursion

Target audience: Intermediate
Estimated reading time: 4'
The most effective way to truly grasp a programming language is by diving in and tackling unique algorithm challenges. In this article, we'll explore common uses of recursion, taking full advantage of Python's dynamic capabilities.

Let's look at some coding gems and let the fun begin!

Table of contents

       Find a line with max number of points
       Optimize coins distribution
       Test if a list has a cycle
       Intersection of sorted arrays
       Sequence of consecutive integers with highest sum
       Longest sequence of increasing values
       List of items which value equals its index
       Find first duplicate in a string
       Check if a binary tree is balanced

What you will learn: How to apply recursion in Python to solve search, optimization and computational problems.

Note: The implementation uses Python 3.11

Introduction

Recursion involves solving a problem by breaking it down into simpler instances of the same problem. Essentially, it's a method where a procedure references itself during its execution.

Tail recursion, a specific kind of recursion, occurs when the function's last action is the recursive call. In languages that support tail-call elimination it can be optimized to avoid unnecessary stacking, allowing the recursive calls to go deeper without overwhelming the call stack.
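
As a quick, purely illustrative example, here is a plain recursive factorial next to a tail-recursive variant carrying an accumulator; note that CPython does not eliminate tail calls, so both versions still consume one stack frame per call.

def factorial(n: int) -> int:
    """ Plain recursion: the multiplication happens after the recursive call returns """
    return 1 if n <= 1 else n * factorial(n - 1)


def factorial_tail(n: int, acc: int = 1) -> int:
    """ Tail recursion: the recursive call is the last action, carrying an accumulator """
    return acc if n <= 1 else factorial_tail(n - 1, acc * n)


print(factorial(6))       # 720
print(factorial_tail(6))  # 720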

Implicit recursion is another variant where a function indirectly calls itself, without an overt recursive statement.

Coding exercises

Find a line with max number of points

Problem: Given n points on a 2-dimensional plane, find the maximum number of points that lie on the same line.

Solution: From a mathematical standpoint, N points with coordinates (x, y) are aligned if the slope dy/dx between each pair of these N points is the same (e.g. (1, 1), (3, 3) and (6, 6) belong to the line y = x). This implementation uses recursion. We define a data class Point with a basic == operator and a slope computation method. Time complexity O(n²).

Implementation:

from dataclasses import dataclass
from typing import List, Dict


@dataclass
class Point:
    """ 
      Convenient class wrapper to implement the == operator and computation of the slope 
    """
    x: int
    y: int

    # Override the == operator for this class
    def __eq__(self, other: 'Point') -> bool:
        return self.x == other.x and self.y == other.y

    # Compute the slope between any two points
    def slope(self, other: 'Point') -> float:
        if self == other:        # Safeguard against comparing a point with itself
            return -1.0
        elif self.x == other.x:  # Avoid division by zero for vertical lines
            return 1e+8
        else:                    # Otherwise compute the slope
            return float(self.y - other.y) / (self.x - other.x)




def max_num_points(points: List[Point]) -> List[Point]:
    """ 
    Compute the points which belong to the same line (same slope) using recursion
    """
    # Execute the recursion along the list of points using the list index
    def align(index: int, visited: Dict[float, List[Point]]) -> Dict[float, List[Point]]:
        if index >= len(points) - 1:
            return visited
        else:
            for pt in points:
                # Make sure we exclude the data point currently visited
                if not points[index] == pt:
                    this_slope = points[index].slope(pt)
                    # Update or create the list of points for this slope
                    new_visited: List[Point] = visited[this_slope] if this_slope in visited else []

                    # Avoid duplicates in the list of points associated with a slope
                    if pt not in new_visited:
                        new_visited.append(pt)
                        # Update the dictionary { slope: list of points with same slope }
                        visited[this_slope] = new_visited

            # Process the next point in the list
            return align(index + 1, visited)

    if len(points) == 0:
        return []
    else:
        all_slopes_dict = align(0, {})

        # Finally extract the list of points associated with
        # the slope with the maximum number of points
        max_num_pts = -1
        output_points = None
        for slope, pts in all_slopes_dict.items():
            if max_num_pts < len(pts):
                max_num_pts = len(pts)
                output_points = pts

        return output_points


Test:
input_points = [
  Point(1, 1), Point(2, 4), Point(3, 3), Point(4, 1), Point(5, 11), Point(6, 6), Point(8, 13)
]

print(str(max_num_points(input_points))) 
    # [Point(x=3, y=3), Point(x=6, y=6), Point(x=1, y=1)]



Optimize coins distribution

Problem: Find the minimum set of coins from a set of coin_types (1 cent, 5 cents, 10 cents, ...) required to foot a given bill (meet a target_amount). For example, 168 = 100 cents * 1 + 50 cents * 1 + 10 cents * 1 + 5 cents * 1 + 1 cent * 3.

Solution: Use recursion, at each step assigning as many coins as possible of the largest remaining denomination. The recursion exits when either all the various types of coins have been used or the target amount has been reached. This greedy strategy is optimal for canonical coin systems such as US currency. The time complexity is O(n).

Implementation:

from typing import List

class CoinsDistribution(object):
    """
        Find the optimal distribution of a set of coins (1 cent, 5 cents, 10 cents, ...)
        to foot a given bill (i.e. 127 cents = 100 cents*1 + 25 cents*1 + 1 cent*2).
        Recursion on the remaining amount, moving to the next coin type at each step.
    """
    def __init__(self, coins: List[int]):
        # We do not assume the coins are actually sorted
        self.coin_types = sorted(coins, reverse=True)

    def find(self, target_amount: int) -> List[int]:

        def _find(left_over_amount: int, index: int, coin_distribution: List[int]) -> List[int]:
            # Recursion exit condition (amount reached or no more coin types left)
            if left_over_amount == 0 or index >= len(self.coin_types):
                return coin_distribution
            else:
                remaining_amount = left_over_amount

                # Assign as many coins as possible for this type of coin
                while remaining_amount >= self.coin_types[index]:
                    remaining_amount -= self.coin_types[index]
                    coin_distribution[index] += 1
                # Move to the next type of coin
                return _find(remaining_amount, index + 1, coin_distribution)

        return _find(target_amount, 0, [0] * len(self.coin_types))


Test:
coin_types = [1, 5, 10, 25, 50, 100]
target_amount = 376
coins_distribution = CoinsDistribution(coin_types)
distribution = coins_distribution.find(target_amount)

acc = [f'{distribution}*{coins_distribution.coin_types[index]}'
           for index, distribution in enumerate(distribution) if distribution > 0]
print(' + '.join(acc))  # 3*100 + 1*50 + 1*25 + 1*1



Test if a list has a cycle

Problem: Determine whether a list, input_values, has a cycle. The problem is equivalent to finding the first duplicate in the list.

Solution: Use two iterators, the second iterator advancing twice as fast as the first one. Time complexity O(n).

Implementation:
def has_cycle(input_values: List[int]) -> bool:
  """ Test if a list has a duplicate or a cycle, using two iterators """
  iter1 = iter(input_values)
  iter2 = iter(input_values)

  try:
    while True:
      value1 = next(iter1)
      next(iter2)               # iter2 advances two elements per iteration
      value2 = next(iter2)

      if value1 == value2:
          return True           # Found, exit

  except StopIteration:
    return False

Test:
values1 = [2, 4, 6, 3, 11, 7, 9, 14, 17]
values2 = [2, 4, 6, 3, 11, 6, 9, 14, 17]

print(has_cycle(values1))  # False
print(has_cycle(values2))  # True


Intersection of sorted arrays

Problem: Find the intersection of two sorted lists, list1 and list2, of integers.

Solution: Map each integer n in the two lists to a tuple (n, list_index) and push the tuples into a priority queue using the heapq module. Record two consecutive items popped from the priority queue that come from different lists and have the same value. The time complexity is O(n log n).

Implementation:
import heapq
from typing import List, Tuple

def intersect_sorted_lists(list1: List[int], list2: List[int]) -> List[int]:
  """ 
    Extract the intersection of two sorted lists of different sizes using 
    a priority queue and recursion
  """
  # Take care of the basic case: the two ranges do not overlap
  if list1[-1] < list2[0] or list1[0] > list2[-1]:
     return []

  else:
     # Otherwise convert each integer n into a tuple (n, list_index)
     pqueue = []
     [heapq.heappush(pqueue, (item, 0)) for item in list1]
     [heapq.heappush(pqueue, (item, 1)) for item in list2]

     # Recursively pop tuples from the priority queue
     def intersect(new_tuple: Tuple, tuples_list: List[Tuple]) -> List[Tuple]:
       if len(pqueue) == 0:  # Our recursion exit condition
         return tuples_list
       else:
         # Next tuple in the priority queue
         item, index = heapq.heappop(pqueue)

         # If two consecutive integers from different lists have the same value...
         if new_tuple[1] == item and index != new_tuple[0]:
           tuples_list.append(new_tuple)
         return intersect((index, item), tuples_list)

     first_item, first_index = heapq.heappop(pqueue)
     intersect_tuples = intersect((first_index, first_item), [])

     return [item for _, item in intersect_tuples]

Test:
values1 = [0, 2, 8, 10, 12, 34, 46, 48, 54, 99]
values2 = [3, 6, 8, 12, 14, 15, 18, 19, 22, 40, 41, 44, 45, 46, 50, 53]

print(intersect_sorted_lists(values1, values2))  # [8, 12, 46]


Sequence of consecutive integers with highest sum

Problem: Find the sequence of num_items consecutive integers from a given list, input_list which produces the highest sum. For example, extracting 3 consecutive integers with highest sum from the list [2, 1, 5, 9, 3, 2] will produce [5, 9, 3]

Solution: Implement an efficient tail recursion that generates a tuple (sum of the sequence, starting index of the sequence).

Implementation:
from typing import List, Tuple

def extract_seq_max_sum(input_list: List[int], num_items: int) -> Tuple[int, int]:
  """ Extract the sequence of consecutive integers with the maximum summation """

  def _extract_seq_max_sum(
     input_list: List[int],
     num_items: int,
     max_sum: int,
     start_index: int,
     cnt: int) -> Tuple[int, int]:

     # Condition to exit the recursion
     if len(input_list) < num_items:
       return max_sum, start_index

     new_sum = sum(input_list[:num_items])
     new_max_sum, new_start_index = \
        (new_sum, cnt) if new_sum > max_sum \
        else (max_sum, start_index)

     return _extract_seq_max_sum(input_list[1:], num_items, new_max_sum, new_start_index, cnt + 1)

  return _extract_seq_max_sum(input_list, num_items, 0, 0, 0)



Test:
values = [4, 2, 8, 6, 12, 34, 6, 8, 4, 9, 11, 2, 17, 22, 5, 8, 6, 1, 4, 13, 19]

print(extract_seq_max_sum(values, 3))  # (52, 3)  -> [6, 12, 34]
print(extract_seq_max_sum(values, 5))  # (66, 2)  -> [8, 6, 12, 34, 6]


Longest sequence of increasing values

Problem: Extract the longest sequence of increasing values from an arbitrary list of integers. For instance, the longest sequence of increasing values in [2, 1, 4, 5, 8, 3, 5, 0, 2] is [1, 4, 5, 8].

Solution: Implement a recursion, longest_increasing_seq, over the input list of integers by tracking the increasing sequence that contains the current value and comparing it with the longest increasing sequence found so far. The recursion walks through the input values using an iterator.

Implementation:
def longest_increasing_seq(input_values: List[int]) -> List[int]:
  """ Extract the longest increasing sequence of integers from a list using recursion """

  def _longest_increasing_seq(
      input_values_iter,
      cur_increasing_seq: List[int],
      longest_seq: List[int]) -> List[int]:

    while True:
      try:
        # Next element in the list
        next_element = next(input_values_iter)

        # If the current increasing list is empty or the new element > last element,
        # add the new element to the current list and continue
        if len(cur_increasing_seq) == 0 or next_element > cur_increasing_seq[-1]:
          cur_increasing_seq.append(next_element)

        # Otherwise complete the current increasing sequence and
        # update the longest list if necessary
        else:
          new_longest_seq = cur_increasing_seq.copy() \
            if len(cur_increasing_seq) > len(longest_seq) \
            else longest_seq

          # Re-initialize the current increasing list
          cur_increasing_seq.clear()
          cur_increasing_seq.append(next_element)

          # Invoke the next recursion
          return _longest_increasing_seq(input_values_iter, cur_increasing_seq, new_longest_seq)

      except StopIteration:  # Exit point for the recursion
        # Account for an increasing sequence that ends the list
        return cur_increasing_seq if len(cur_increasing_seq) > len(longest_seq) else longest_seq

  return _longest_increasing_seq(iter(input_values), [], [])


Test:
values = [6, 1, 4, 9, 11, 22, 17, 8, 6, 1, 4, 13, 19]  
print(str(longest_increasing_seq(values)))      # [1, 4, 9, 11, 22]
 

List of items which value equals its index

Problem: Find the integers in a list whose value equals their index. For example, the algorithm should select the second item, 1, from the list [3, 1, 5, 0].

Solution: Simple traversal with a time complexity O(n).

Implementation:
def get_values_eq_index(input_list: List[int]) -> List[int]:
  """
    Retrieve the list of elements for which the value is equal to its index in the list
  """
  match len(input_list):
    case 0:
      return []
    case 1:
      # A single element matches only if its value is 0 (index 0)
      return input_list if input_list[0] == 0 else []
    case _:
      return [el for index, el in enumerate(input_list) if index == el]


Test:
input_list = [2, 9, 2, 5, 4, 8, 1, 5, 8, 10]
print(get_values_eq_index(input_list))           # [2, 4, 8]


Find first duplicate in a string

Problem: Find the first duplicate character in a string.

Solution: Use a set to keep track of visited characters. Worst-case time complexity O(n). This is an example for which recursion is not warranted.

Implementation:
from typing import AnyStr, Optional

def get_first_duplicate(input_str: AnyStr) -> Optional[AnyStr]:
  """ Get the first duplicated character in a string if found, None otherwise """

  match len(input_str):
    case 0 | 1:
       return None           # No duplicate possible in an empty or single-character string
    case _:                  # After dealing with the special cases
       unique_set = set()
       for ch in input_str:
          if ch in unique_set:
             return ch
          unique_set.add(ch)

       return None

Test:
input_str1 = "helo"
input_str2 = "hello"

print(get_first_duplicate(input_str1))  # None
print(get_first_duplicate(input_str2))  # 'l'


Check if a binary tree is balanced

Problem: Test whether a binary tree, defined by its node class Node, is balanced; that is, every path between the root and any leaf should have the same number of nodes.

Solution: Apply an in-order traversal to collect the lengths of all the lineages from the root to the leaves, then compute the difference between the longest and shortest lineages. The traversal of the tree, the method __walk, uses recursion. Time complexity O(n). 

Implementation:
from typing import AnyStr, List

class Node(object):
  """ Basic binary tree node """
  def __init__(self, data: int):
    self.left = None
    self.right = None
    self.data = data

  def add(self, data: int) -> None:
    """
       Simple insertion into a binary tree, appending a new node if necessary
    """
    if data < self.data:
      if self.left:
        self.left.add(data)
      else:
        self.left = Node(data)
    elif data > self.data:
      if self.right:
        self.right.add(data)
      else:
        self.right = Node(data)
    else:
      self.data = data

  def is_balanced(self) -> bool:
    path_acc = []
    self.__walk('', path_acc)

    if path_acc:
      # Measure the length of the various paths from root to leaves
      path_sizes = [len(x) for x in path_acc if x != '']
      return max(path_sizes) - min(path_sizes) < 1
    else:
      return True

  # ---------   Private helper method ---------

  def __walk(self, path: AnyStr, lineages: List[AnyStr]):
    """ DFS walk recording the path of each visited node """
    if self.left:    # This node has a left child
      self.left.__walk(f'{path}L', lineages)
      lineages.append(path)

    if self.right:   # This node has a right child
      self.right.__walk(f'{path}R', lineages)
      lineages.append(path)


Test:
root = Node(1)
values = [2, 9, 8, 10, 13, 7]
[root.add(n) for n in values]
print(root.is_balanced())        # False

root = Node(10)
values = [6, 14, 4, 8, 12, 16]
[root.add(n) for n in values]
print(root.is_balanced())        # True


Thank you for reading this article. For more information ...

