# Untitled
#
# mail@pastecode.io avatar
# unknown
# python
# 2 years ago
# 1.7 kB
# 2
# Indexable
# Never
def find_best_split(data):
  """
  Find the single-feature threshold split with the highest information gain.

  For every column of `data`, each distinct non-missing value v is tried as
  a threshold: rows whose value is <= v form one branch and all remaining
  rows (including rows where the value is missing) form the other branch.
  Each candidate is scored by calling `information_gain(data, mask)`, where
  `mask` is a boolean Series that shares the index of `data`.

  args:
  * data(type: DataFrame): the input data
  return
  * best_ig(type: float): the best information gain you obtain
  * best_threshold(type: float): the value that splits data into 2 branches
  * best_feature(type: string): the feature that splits data into 2 branches

  When no candidate scores above 0.0 (for example when `data` has no rows),
  the initial values (0.0, 0.0, '') are returned.
  """
  best_ig, best_threshold, best_feature = 0.0, 0.0, ''

  # NOTE(review): every column is treated as a candidate feature. If `data`
  # also carries the label column that information_gain() scores against,
  # that column would presumably split perfectly -- confirm that the caller
  # drops it or that this is intended.
  for feature in data:
    column = data[feature]

    # Distinct values in ascending order. unique() keeps order of first
    # appearance, so sorting first yields sorted candidates; tied values
    # therefore produce exactly one candidate instead of one per row.
    candidates = column.dropna().sort_values().unique()

    for threshold in candidates:
      # Build the mask by value rather than by row label so that ties and
      # arbitrary index labels are handled, and keep the name 'mask' that
      # the Series passed to information_gain() has always had.
      mask = (column <= threshold).rename('mask')

      # A split that sends every row to the same branch separates nothing,
      # so there is no point in scoring it.
      if mask.all():
        continue

      ig = information_gain(data, mask)
      # Strict '>' keeps the first candidate found among equal scores.
      if ig > best_ig:
        best_ig, best_threshold, best_feature = ig, threshold, feature

  return best_ig, best_threshold, best_feature


# [Note] The values of "ans_ig", "ans_value", and "ans_name" must also be
# saved into the output file.
ans_ig, ans_value, ans_name = find_best_split(input_data)

# Report the three results, one per line. Each line is the label, a single
# space, then the value, which is exactly what print(label, value) renders.
print("\n".join(
  " ".join((label, str(value)))
  for label, value in (
    ("ans_ig = ", ans_ig),
    ("ans_value = ", ans_value),
    ("ans_name = ", ans_name),
  )
))