# Untitled
#
# avatar
# unknown
# plain_text
# 2 years ago
# 2.5 kB
# 3
# Indexable
def _majority_label(data, tie_label):
    """Return the majority class (0 or 1) of the 'diabetes_mellitus' column.

    Values equal to 1 count as class 1; every other value counts as class 0.

    Args:
        data: table whose 'diabetes_mellitus' column holds the labels.
        tie_label: label to return when both classes have the same count.

    Returns:
        1 or 0 for a clear majority, otherwise ``tie_label``.
    """
    labels = data['diabetes_mellitus']
    ones = sum(1 for value in labels if value == 1)
    others = len(labels) - ones
    if ones == others:
        return tie_label
    return 1 if ones > others else 0


def build_tree(data, max_depth, min_samples_split, depth):
    """Recursively build a binary decision tree for 'diabetes_mellitus'.

    Args:
        data: DataFrame with the samples that reached this node; must contain
            the 'diabetes_mellitus' label column.
        max_depth: a node is only split while ``depth < max_depth``.
        min_samples_split: a node is only split when it holds strictly more
            than this many rows.
        depth: depth of the current node (the caller passes the root depth).

    Returns:
        Either a leaf label (0 or 1), or a dict ``{question: [left, right]}``
        where ``question`` is the string ``"<feature> <= <threshold>"`` and
        ``left`` / ``right`` are the subtrees built from the two partitions
        returned by ``make_partition``. When both subtrees compare equal the
        split is dropped and that subtree is returned directly.
    """
    # Stop splitting at the depth limit or when too few samples remain.
    # NOTE(review): ties resolve to 1 here but to 0 in the "no useful split"
    # leaves below; both rules are preserved from the original code -- confirm
    # which one is intended and unify.
    if depth >= max_depth or data.shape[0] <= min_samples_split:
        return _majority_label(data, tie_label=1)

    ig, threshold, feature = find_best_split(data)

    # Written as a negation so a NaN gain is also turned into a leaf.
    if not ig > 0:
        return _majority_label(data, tie_label=0)

    left, right = make_partition(data, feature, threshold)

    # A split that leaves one side empty separates nothing, so this node
    # becomes a leaf. The original returned the feature name here instead of
    # a class label.
    if left.shape[0] == 0 or right.shape[0] == 0:
        return _majority_label(data, tie_label=0)

    left_subtree = build_tree(left, max_depth, min_samples_split, depth + 1)
    right_subtree = build_tree(right, max_depth, min_samples_split, depth + 1)

    # Both sides predict the same thing: the question is redundant.
    if left_subtree == right_subtree:
        return left_subtree

    question = "{} <= {}".format(feature, threshold)
    return {question: [left_subtree, right_subtree]}