Untitled
unknown
python
3 years ago
1.7 kB
4
Indexable
def find_best_split(data):
    """
    Find the single (feature, threshold) split with the highest information gain.

    Every column of ``data`` is tried as a split feature. For each column the
    rows are sorted by that column, and every boundary between two *different*
    consecutive values is scored with ``information_gain``. Rows whose value is
    less than or equal to the candidate threshold form the ``True`` side of the
    boolean mask handed to ``information_gain``.

    args:
    * data(type: DataFrame): the input data
    return
    * best_ig(type: float): the best information gain you obtain
    * best_threshold(type: float): the value that splits data into 2 branches
      (the largest feature value on the ``True`` side of the mask)
    * best_feature(type: string): the feature that splits data into 2 branches

    If no candidate scores strictly above 0.0, the initial values
    ``(0.0, 0.0, '')`` are returned unchanged.
    """
    best_ig, best_threshold, best_feature = 0.0, 0.0, ''
    n_rows = data.shape[0]
    positions = np.arange(n_rows)

    # NOTE(review): every column is tried, including a label/target column if
    # `data` contains one -- confirm the caller or `information_gain` expects that.
    for feature in data:
        sorted_data = data.sort_values([feature], ascending=True)
        values = sorted_data[feature].to_numpy()

        # Stop one short of the end: putting every row on the True side leaves
        # the other branch empty, which is not a split.
        for pos in range(n_rows - 1):
            # A threshold cannot separate equal values, so only score the
            # boundary after the last row of a run of duplicates. The neighbour
            # is the next row in *sorted* order, not the next index label.
            if values[pos] == values[pos + 1]:
                continue

            # True for the first pos + 1 sorted rows, so rows skipped above as
            # duplicates are still included. The mask shares sorted_data's
            # index, so it lines up with it both by label and by position.
            # Assumes information_gain(frame, boolean_mask) -- TODO confirm
            # against its definition.
            mask = pd.Series(positions <= pos, index=sorted_data.index, name='mask')
            ig = information_gain(sorted_data, mask)

            if ig > best_ig:
                best_ig = ig
                best_threshold = values[pos]
                best_feature = feature

    return best_ig, best_threshold, best_feature


# [Note] You have to save the value of "ans_ig", "ans_value", and "ans_name" into the output file
ans_ig, ans_value, ans_name = find_best_split(input_data)
print("ans_ig = ", ans_ig)
print("ans_value = ", ans_value)
print("ans_name = ", ans_name)
Editor is loading...