Untitled
unknown
python
4 years ago
4.9 kB
6
Indexable
#!/usr/bin/env python3
"""Decision-tree helpers: CSV loading, impurity measures and tree building."""
from collections import Counter
from math import isfinite, log
from typing import Dict, List, Optional, Tuple
import sys

# Used for typing
Data = List[List]


def read(file_name: str) -> Tuple[Optional[List[str]], Data]:
    """
    t3: Load the data into a bidimensional list.

    The first non-blank line is taken as the header row; every following
    non-blank line becomes one data row whose fields are converted with
    ``_parse_value``. Fields are split on plain commas (no quoting support).

    Return the headers as a list, and the data. ``headers`` is ``None``
    when the file contains no non-blank line.
    """
    headers = None
    data = []
    with open(file_name) as file_:
        for line in file_:
            stripped = line.strip()
            if not stripped:
                # A blank line (typically a trailing newline) is not a row;
                # keeping it would add a bogus [''] sample to the data.
                continue
            values = stripped.split(',')
            if headers is None:
                headers = values
            else:
                data.append([_parse_value(field) for field in values])
    return headers, data


def _parse_value(value: str):
    """
    Convert a CSV field to ``int`` or ``float`` when possible.

    Anything ``int()`` accepts becomes an int; otherwise anything ``float()``
    accepts becomes a float, except non-finite results ("nan", "inf", ...),
    which are kept as text so such labels stay categorical. Every other
    field is returned unchanged as a string.
    """
    # EAFP: attempt the conversion instead of pre-checking with
    # str.isnumeric(), which rejects "1.5" and "-3".
    try:
        return int(value)
    except ValueError:
        pass
    try:
        number = float(value)
    except ValueError:
        return value
    return number if isfinite(number) else value


def unique_counts(part: Data) -> Dict[str, int]:
    """
    t4: Create counts of possible results
    (the last column of each row is the result)
    """
    return Counter(row[-1] for row in part)


def gini_impurity(part: Data):
    """
    t5: Computes the Gini index of a node

    Returns 0 for an empty partition.
    """
    total = len(part)
    if total == 0:
        return 0

    results = unique_counts(part)
    imp = 1
    for value in results.values():
        p = value / total
        imp -= p ** 2
    return imp


def _log2(value: float):
    """Return the base-2 logarithm of ``value``."""
    return log(value) / log(2)


def entropy(rows: Data):
    """
    t6: Entropy is the sum of p(x)log(p(x))
    across all the different possible results

    Returns 0 for an empty partition (there are no results to sum over).
    """
    results = unique_counts(rows)
    total = len(rows)
    imp = 0
    for value in results.values():
        p = value / total
        imp -= p * _log2(p)
    return imp


def _split_numeric(prototype: List, column: int, value):
    """Numeric criterion: the row's column is greater than or equal to ``value``."""
    return prototype[column] >= value


def _split_categorical(prototype: List, column: int, value):
    """Categorical criterion: the row's column equals ``value``."""
    return prototype[column] == value


def divideset(part: Data, column: int, value) -> Tuple[Data, Data]:
    """
    t7: Divide a set on a specific column.
    Can handle numeric or categorical values

    Returns ``(set1, set2)``: the rows that satisfy the criterion and the
    rows that do not, each in their original order. An ``int``/``float``
    ``value`` splits with ``>=``; any other value splits with ``==``.
    """
    if isinstance(value, (int, float)):
        split_function = _split_numeric
    else:
        split_function = _split_categorical

    set1 = []
    set2 = []
    for row in part:
        if split_function(row, column, value):
            set1.append(row)
        else:
            set2.append(row)
    return set1, set2


class DecisionNode:
    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        """
        t8: We have 5 member variables:
        - col is the column index which represents the
          attribute we use to split the node
        - value corresponds to the answer that satisfies
          the question
        - tb and fb are internal nodes representing the
          positive and negative answers, respectively
        - results is a dictionary that stores the result
          for this branch. Is None except for the leaves
        """
        self.col = col
        self.value = value
        self.results = results
        self.tb = tb
        self.fb = fb


def _best_split(part: Data, scoref):
    """
    Find the single split of ``part`` with the largest impurity decrease.

    Every column except the last (the label) and every value seen in that
    column is tried. Splits that leave one side empty are ignored. The
    first row determines how many columns are examined.

    Returns ``(best_gain, best_criteria, best_sets)`` where ``best_criteria``
    is ``(column, value)`` and ``best_sets`` is ``(set1, set2)``; both are
    ``None`` (and ``best_gain`` is 0) when no split has a positive gain.
    """
    current_score = scoref(part)
    total = len(part)

    # Set up some variables to track the best criteria
    best_gain = 0
    best_criteria = None
    best_sets = None

    for column in range(len(part[0]) - 1):
        # dict.fromkeys de-duplicates while keeping first-seen order, so ties
        # between equally good splits are resolved the same way on every run.
        for value in dict.fromkeys(row[column] for row in part):
            set1, set2 = divideset(part, column, value)
            if not set1 or not set2:
                continue
            p = len(set1) / total
            gain = current_score - p * scoref(set1) - (1 - p) * scoref(set2)
            if gain > best_gain:
                best_gain = gain
                best_criteria = (column, value)
                best_sets = (set1, set2)
    return best_gain, best_criteria, best_sets


def buildtree(part: Data, scoref=entropy, beta=0):
    """
    t9: Define a new function buildtree. This is a recursive function
    that builds a decision tree using any of the impurity measures we
    have seen. The stop criterion is: the largest impurity decrease over
    all candidate splits does not exceed beta.

    ``scoref`` is the impurity function (``entropy`` or ``gini_impurity``).
    A node is split only when the best gain is strictly greater than
    ``beta``; otherwise it becomes a leaf whose ``results`` holds the label
    counts. An empty ``part`` yields an empty ``DecisionNode()``.
    """
    if not part:
        return DecisionNode()

    best_gain, best_criteria, best_sets = _best_split(part, scoref)

    # best_criteria is None when no split improves the score; checking it
    # keeps a negative beta from dereferencing a missing split.
    if best_criteria is not None and best_gain > beta:
        column, value = best_criteria
        return DecisionNode(
            col=column,
            value=value,
            tb=buildtree(best_sets[0], scoref, beta),
            fb=buildtree(best_sets[1], scoref, beta),
        )
    return DecisionNode(results=unique_counts(part))


def iterative_buildtree(part: Data, scoref=entropy, beta=0):
    """
    t10: Define the iterative version of the function buildtree

    Builds the same tree as ``buildtree`` but uses an explicit stack of
    ``(rows, node)`` pairs instead of recursion: each pending node is either
    turned into a leaf or given a criterion and two fresh children.
    """
    if not part:
        return DecisionNode()

    root = DecisionNode()
    pending = [(part, root)]
    while pending:
        rows, node = pending.pop()
        best_gain, best_criteria, best_sets = _best_split(rows, scoref)
        if best_criteria is not None and best_gain > beta:
            node.col, node.value = best_criteria
            node.tb = DecisionNode()
            node.fb = DecisionNode()
            pending.append((best_sets[0], node.tb))
            pending.append((best_sets[1], node.fb))
        else:
            node.results = unique_counts(rows)
    return root


def print_tree(tree, headers=None, indent=""):
    """
    t11: Include the following function

    Prints the tree to stdout: leaves print their ``results``; internal
    nodes print their criterion followed by the true and false branches.
    When ``headers`` is given, the column name is shown instead of its index.
    """
    # Is this a leaf node?
    if tree.results is not None:
        print(tree.results)
    else:
        # Print the criteria
        criteria = tree.col
        if headers:
            criteria = headers[criteria]
        print(f"{indent}{criteria}: {tree.value}?")

        # Print the branches
        print(f"{indent}T->")
        print_tree(tree.tb, headers, indent + " ")
        print(f"{indent}F->")
        print_tree(tree.fb, headers, indent + " ")


def main():
    """Load the CSV named on the command line and print a few sample splits."""
    if len(sys.argv) < 2:
        # Fail with a readable message instead of a bare IndexError.
        sys.exit("usage: <script> <csv-file>")

    header, data = read(sys.argv[1])
    print(header)
    for row in data:
        print(row)

    for value in ["USA", "fRANCE", "UK", "NewZealand", ]:
        set1, set2 = divideset(data, 1, value)
        print("Split by: ", value)
        print(gini_impurity(set1))
        print(gini_impurity(set2))
        print(entropy(set1))
        print(entropy(set2))


if __name__ == "__main__":
    main()
Editor is loading...