class MyDecisionTreeNode:
    """
    Auxiliary class serving as representation of a decision tree node.

    A node is either terminal (a leaf holding a predicted class and a class-probability
    vector) or regular (holding a split condition on one feature and two subtrees).
    During prediction, objects with ``x[feature_id] <= threshold`` go to the left
    subtree and all others go to the right subtree.
    """

    def __init__(
        self,
        meta: 'MyDecisionTreeClassifier',
        depth: int,
        node_type: NodeType = NodeType.REGULAR,
        predicted_class: tp.Optional[tp.Union[int, str]] = None,
        left_subtree: tp.Optional['MyDecisionTreeNode'] = None,
        right_subtree: tp.Optional['MyDecisionTreeNode'] = None,
        feature_id: tp.Optional[int] = None,
        threshold: tp.Optional[float] = None,
        impurity: float = np.inf
    ):
        """
        :param meta: object holding meta information about the tree; this class reads its
            ``max_depth``, ``min_samples_split``, ``rng`` and ``_n_classes`` attributes
        :param depth: depth of this node in the tree (a child gets the depth of its parent plus one)
        :param node_type: ``NodeType.REGULAR`` or ``NodeType.TERMINAL`` depending on whether
            this node is a leaf
        :param predicted_class: class label assigned to a terminal node
        :param left_subtree: child used for objects with ``x[feature_id] <= threshold``
        :param right_subtree: child used for objects with ``x[feature_id] > threshold``
        :param feature_id: index of the feature to split by
        :param threshold: split threshold applied to feature ``feature_id``
        :param impurity: impurity of the labels reaching this node (passed in by the parent
            when it creates this node as a child of a split)
        """
        self._node_type = node_type
        self._meta = meta
        self._depth = depth
        self._predicted_class = predicted_class
        self._class_proba = None  # set only for terminal nodes
        self._left_subtree = left_subtree
        self._right_subtree = right_subtree
        self._feature_id = feature_id
        self._threshold = threshold
        self._impurity = impurity

    def _best_split(self, X: np.ndarray, y: np.ndarray):
        """
        Find the split with the lowest weighted impurity.

        Splits that leave one side empty are skipped: they separate nothing and would make
        the fit of the empty child fail.

        :param X: data passed to the node, shape (n_samples, n_features)
        :param y: labels of the samples in ``X``
        :return: best feature, best threshold, left child impurity, right child impurity;
            all four are ``None`` when no split with two non-empty sides exists
        """
        lowest_impurity = np.inf
        best_feature_id = None
        best_threshold = None
        lowest_left_child_impurity, lowest_right_child_impurity = None, None
        # Features are visited in random order; together with the strict ``<`` below this
        # means ties between equally good splits go to a randomly chosen feature.
        features = self._meta.rng.permutation(X.shape[1])
        for feature in features:
            current_feature_values = X[:, feature]
            for threshold in np.unique(current_feature_values):
                left_idx, right_idx = create_split(current_feature_values, threshold)
                y_left, y_right = y[left_idx], y[right_idx]
                if y_left.size == 0 or y_right.size == 0:
                    continue
                current_weighted_impurity, current_left_impurity, current_right_impurity = \
                    weighted_impurity(y_left, y_right)
                if current_weighted_impurity < lowest_impurity:
                    lowest_impurity = current_weighted_impurity
                    best_feature_id = feature
                    best_threshold = threshold
                    lowest_left_child_impurity = current_left_impurity
                    lowest_right_child_impurity = current_right_impurity
        return best_feature_id, best_threshold, lowest_left_child_impurity, lowest_right_child_impurity

    def _make_terminal(self, y: np.ndarray) -> 'MyDecisionTreeNode':
        """
        Turn this node into a leaf predicting the most common class of ``y``.

        :param y: labels of the samples that reached this node; they are used as indices
            into a vector of length ``self._meta._n_classes``, so they must be
            integers in ``[0, n_classes)``
        :return: this node
        """
        self._node_type = NodeType.TERMINAL
        labels, counts = np.unique(y, return_counts=True)
        # np.unique returns sorted labels, so argmax breaks ties towards the smallest
        # label -- the same choice as np.bincount(y).argmax().
        self._predicted_class = labels[counts.argmax()]
        # Vector of probabilities of all classes, indexed by label 0 .. n_classes - 1.
        class_proba = np.zeros(self._meta._n_classes)
        class_proba[labels] = counts / len(y)
        self._class_proba = class_proba
        return self

    def fit(self, X: np.ndarray, y: np.ndarray) -> 'MyDecisionTreeNode':
        """
        Recursively fit the node, providing it with a predicted class or a split condition.

        :param X: data reaching this node, shape (n_samples, n_features)
        :param y: labels of the samples in ``X``
        :return: fitted node (this node)
        """
        num_samples = np.shape(X)[0]
        # Stop when the node is pure, the depth limit is reached, or too few samples remain.
        # NOTE(review): ``<=`` means a node with exactly ``min_samples_split`` samples is not
        # split; sklearn's parameter of the same name uses ``<`` -- confirm intended semantics.
        if (
            len(np.unique(y)) <= 1
            or self._depth >= self._meta.max_depth
            or num_samples <= self._meta.min_samples_split
        ):
            return self._make_terminal(y)

        feature_id, threshold, left_imp, right_imp = self._best_split(X, y)
        if feature_id is None:
            # No split yields two non-empty children (e.g. every feature is constant here).
            return self._make_terminal(y)

        self._feature_id, self._threshold = feature_id, threshold
        left_idx, right_idx = create_split(X[:, feature_id], threshold)
        self._left_subtree = MyDecisionTreeNode(
            meta=self._meta,
            depth=self._depth + 1,
            impurity=left_imp
        ).fit(X[left_idx], y[left_idx])
        self._right_subtree = MyDecisionTreeNode(
            meta=self._meta,
            depth=self._depth + 1,
            impurity=right_imp
        ).fit(X[right_idx], y[right_idx])
        return self

    def _find_leaf(self, x: np.ndarray) -> 'MyDecisionTreeNode':
        """
        Descend from this node to the terminal node that the object ``x`` falls into.

        :param x: object of shape (n_features, )
        :return: terminal node reached by ``x``
        """
        node = self
        # Iterative descent: avoids recursion and cannot lose the result on the way back up.
        while node._node_type is not NodeType.TERMINAL:
            if x[node._feature_id] <= node._threshold:
                node = node._left_subtree
            else:
                node = node._right_subtree
        return node

    def predict(self, x: np.ndarray):
        """
        Predict the class of a single object.

        :param x: object of shape (n_features, )
        :return: class assigned to the object
        """
        return self._find_leaf(x)._predicted_class

    def predict_proba(self, x: np.ndarray):
        """
        Predict class probabilities for a single object.

        :param x: object of shape (n_features, )
        :return: vector of class probabilities of length ``n_classes``
        """
        return self._find_leaf(x)._class_proba