Untitled
unknown
python
a month ago
5.7 kB
0
Indexable
Never
class MyDecisionTreeNode:
    """
    Auxiliary class serving as representation of a decision tree node
    """

    def __init__(
            self,
            meta: 'MyDecisionTreeClassifier',
            depth: int,
            node_type: NodeType = NodeType.REGULAR,
            predicted_class: tp.Optional[tp.Union[int, str]] = None,
            left_subtree: tp.Optional['MyDecisionTreeNode'] = None,
            right_subtree: tp.Optional['MyDecisionTreeNode'] = None,
            feature_id: tp.Optional[int] = None,
            threshold: tp.Optional[float] = None,
            impurity: float = np.inf
    ):
        """
        :param meta: object holding meta information shared by the whole tree
            (this class reads meta.rng, meta.max_depth, meta.min_samples_split and meta._n_classes)
        :param depth: depth of this node in the tree (children created in fit get parent depth + 1)
        :param node_type: NodeType.REGULAR for an internal node, NodeType.TERMINAL for a leaf
        :param predicted_class: class label assigned to a terminal node
        :param left_subtree: child that receives objects with x[feature_id] <= threshold
        :param right_subtree: child that receives objects with x[feature_id] > threshold
        :param feature_id: index of the feature to split by
        :param threshold: split threshold for the feature at index feature_id
        :param impurity: impurity of the data reaching this node (defaults to np.inf)
        """
        self._node_type = node_type
        self._meta = meta
        self._depth = depth
        self._predicted_class = predicted_class
        self._class_proba = None  # filled only for terminal nodes, see _make_terminal
        self._left_subtree = left_subtree
        self._right_subtree = right_subtree
        self._feature_id = feature_id
        self._threshold = threshold
        self._impurity = impurity

    def _best_split(self, X: np.ndarray, y: np.ndarray):
        """
        Find the split with the lowest weighted impurity over all features and thresholds.

        Features are visited in a random order drawn from self._meta.rng; among equally good
        splits the first one found in that order wins (strict < comparison).
        Splits that leave either child without samples are skipped.

        :param X: data passed to the node, shape (n_samples, n_features)
        :param y: labels for the rows of X
        :return: best feature, best threshold, left child impurity, right child impurity;
            all four are None when no split produces two non-empty children
        """
        lowest_impurity = np.inf
        best_feature_id = None
        best_threshold = None
        lowest_left_child_impurity, lowest_right_child_impurity = None, None

        features = self._meta.rng.permutation(X.shape[1])  # permuted indexes of features
        for feature in features:
            current_feature_values = X[:, feature]  # values of feature at index=feature
            for threshold in np.unique(current_feature_values):
                # find indices for split with current threshold
                left_idx, right_idx = create_split(current_feature_values, threshold)
                y_left, y_right = y[left_idx], y[right_idx]
                if y_left.size == 0 or y_right.size == 0:
                    # A split with an empty side separates nothing and would hand a child no data.
                    continue
                current_weighted_impurity, current_left_impurity, current_right_impurity = \
                    weighted_impurity(y_left, y_right)
                if current_weighted_impurity < lowest_impurity:
                    lowest_impurity = current_weighted_impurity
                    best_feature_id = feature
                    best_threshold = threshold
                    lowest_left_child_impurity = current_left_impurity
                    lowest_right_child_impurity = current_right_impurity

        return best_feature_id, best_threshold, lowest_left_child_impurity, lowest_right_child_impurity

    def _make_terminal(self, y: np.ndarray) -> 'MyDecisionTreeNode':
        """
        Turn this node into a leaf that predicts the most common class of y.

        :param y: labels reaching this node; used as array indices, so they must be
            non-negative integers below self._meta._n_classes
        :return: this node
        """
        self._node_type = NodeType.TERMINAL
        # most common class; on ties argmax returns the smallest label
        self._predicted_class = np.bincount(y).argmax()
        result_proba = np.zeros(self._meta._n_classes)
        unique, count_unique = np.unique(y, return_counts=True)
        # empirical class frequencies, indexed by class label 0..n_classes-1
        result_proba[unique] = count_unique / len(y)
        self._class_proba = result_proba
        return self

    def fit(self, X: np.ndarray, y: np.ndarray) -> 'MyDecisionTreeNode':
        """
        Recursively fit a node, providing it with a predicted class or a split condition.

        The node becomes terminal when its labels are pure, when max_depth is reached,
        when it holds too few samples, or when no split yields two non-empty children.

        :param X: data, shape (n_samples, n_features)
        :param y: labels for the rows of X
        :return: fitted node (self)
        """
        num_samples = np.shape(X)[0]
        if (
            len(np.unique(y)) <= 1  # pure node: nothing left to separate
            or self._depth >= self._meta.max_depth
            # NOTE(review): sklearn's min_samples_split makes a node a leaf only when it has
            # strictly fewer samples (<); this keeps the original <= -- confirm intended semantics.
            or num_samples <= self._meta.min_samples_split
        ):
            return self._make_terminal(y)

        feature_id, threshold, left_imp, right_imp = self._best_split(X, y)
        if feature_id is None:
            # No usable split (e.g. identical rows with different labels): stop here.
            return self._make_terminal(y)

        self._feature_id, self._threshold = feature_id, threshold
        left_idx, right_idx = create_split(X[:, feature_id], threshold)

        self._left_subtree = MyDecisionTreeNode(
            meta=self._meta,
            depth=self._depth + 1,  # children sit one level deeper
            impurity=left_imp
        ).fit(X[left_idx], y[left_idx])
        self._right_subtree = MyDecisionTreeNode(
            meta=self._meta,
            depth=self._depth + 1,
            impurity=right_imp
        ).fit(X[right_idx], y[right_idx])
        return self

    def _find_leaf(self, x: np.ndarray) -> 'MyDecisionTreeNode':
        """
        Walk from this node down to the terminal node that object x falls into.

        :param x: object of shape (n_features, )
        :return: the terminal node reached by x
        """
        node = self
        while node._node_type is not NodeType.TERMINAL:
            if x[node._feature_id] <= node._threshold:
                node = node._left_subtree
            else:
                node = node._right_subtree
        return node

    def predict(self, x: np.ndarray):
        """
        Predicts class for a single object

        :param x: object of shape (n_features, )
        :return: class assigned to object
        """
        return self._find_leaf(x)._predicted_class

    def predict_proba(self, x: np.ndarray):
        """
        Predicts probability for a single object

        :param x: object of shape (n_features, )
        :return: vector of probabilities assigned to object (length meta._n_classes)
        """
        return self._find_leaf(x)._class_proba