# Untitled
# unknown
# python
# 3 years ago
# 1.7 kB
# 7
# Indexable
def find_best_split(data):
    """Find the single (feature, threshold) split with the highest information gain.

    Every column of ``data`` is tried as the split feature. For each column,
    rows are visited in ascending order of that column's value. After the last
    row of each distinct value ``v``, the candidate split "value <= v" (left
    branch, mask True) versus "value > v" (right branch, mask False) is scored
    with ``information_gain``. Rows sharing a value always land on the same
    side. Splitting after the final row (everything on the left) is not a real
    split and is not scored.

    NOTE(review): every column is tried, so if ``data`` still contains the
    label/target column it will be considered as a split feature too. Confirm
    that callers drop it, or that ``information_gain`` accounts for it.

    args:
    * data(type: DataFrame): the input data
    return:
    * best_ig(type: float): the best information gain found (0.0 if no split beats 0.0)
    * best_threshold(type: float): rows whose ``best_feature`` value is <= this
      value form the left branch (0.0 if no split was found)
    * best_feature(type: string): the feature that splits data into 2 branches
      ('' if no split was found)
    """
    best_ig, best_threshold, best_feature = 0.0, 0.0, ''
    n_rows = data.shape[0]
    for feature in data:
        # reset_index gives a 0..n-1 index, so after sorting, the index holds the
        # *positions* of the rows in ascending order of this feature. This does
        # not depend on what data.index looks like.
        sorted_col = data[feature].reset_index(drop=True).sort_values(ascending=True)
        positions = sorted_col.index.to_numpy()
        sorted_vals = sorted_col.to_numpy()

        # left[p] is True when row p (positional, in data's own order) is in the left branch.
        left = np.zeros(n_rows, dtype=bool)
        for pos in range(n_rows - 1):
            # Always move the current row to the left branch, even when we do
            # not score here, so tied rows end up on the same side.
            left[positions[pos]] = True
            if sorted_vals[pos] == sorted_vals[pos + 1]:
                # No boundary between equal values: keep accumulating the tie.
                continue
            # The mask is indexed like `data` and is in data's row order, so it
            # lines up with `data` whether information_gain aligns by label or by
            # position. Copy because `left` keeps being mutated.
            mask = pd.Series(left.copy(), index=data.index, name='mask')
            ig = information_gain(data, mask)
            if ig > best_ig:
                best_ig = ig
                best_threshold = sorted_vals[pos]
                best_feature = feature
    return best_ig, best_threshold, best_feature
# [Note] You have to save the values of "ans_ig", "ans_value", and "ans_name" into the output file.
# `input_data` is expected to be defined earlier in the file (not visible here).
ans_ig, ans_value, ans_name = find_best_split(input_data)
print("ans_ig = ", ans_ig)
print("ans_value = ", ans_value)
print("ans_name = ", ans_name)