Untitled

 avatar
unknown
python
4 years ago
4.9 kB
6
Indexable
#!/usr/bin/env python3
import sys
from collections import Counter
from math import isfinite, log, log2
from typing import Dict, List, Tuple

# Type alias used in annotations: a dataset is a list of rows, each row
# being a list of parsed cell values.
Data = List[List]


def read(file_name: str) -> Tuple[List[str], Data]:
    """
    t3: Load a comma-separated text file into a bidimensional list.

    The first non-blank line is taken as the header row. Every later
    non-blank line becomes one data row whose fields are converted with
    _parse_value. Blank or whitespace-only lines are skipped, so a stray
    empty line does not turn into a bogus one-column row.

    :param file_name: path of the file to read
    :return: (headers, data); headers is None when the file has no
        non-blank line
    """
    headers = None
    data = []
    with open(file_name) as file_:
        for line in file_:
            stripped = line.strip()
            if not stripped:
                # ''.split(',') would yield [''], i.e. a fake row.
                continue
            values = stripped.split(',')
            if headers is None:
                headers = values
            else:
                data.append([_parse_value(field) for field in values])
    return headers, data


def _parse_value(value: str):
    """
    Convert one text field to a number when possible.

    Tries int() first, then float(); text that neither accepts is
    returned unchanged. Strings that float() turns into a non-finite
    value (such as "nan" or "inf") are also returned unchanged, because
    NaN never compares equal to itself and would break value-based splits.

    :param value: raw field text
    :return: an int, a finite float, or the original string
    """
    # Easier to ask forgiveness than permission: let the converters decide
    # instead of pre-checking with str.isnumeric(), which rejects "3.14"
    # and "-5" yet accepts characters like "²" that int() cannot parse.
    try:
        return int(value)
    except ValueError:
        pass
    try:
        number = float(value)
    except ValueError:
        return value
    return number if isfinite(number) else value


def unique_counts(part: Data) -> Dict[str, int]:
    """
    t4: Count how many rows carry each result.

    The result (label) of a row is its last element. Returns a Counter
    mapping each label to its number of occurrences.
    """
    return Counter(row[-1] for row in part)


def gini_impurity(part: Data):
    """
    t5: Gini index of a node: 1 minus the sum of the squared proportion
    of each result. An empty partition scores 0.
    """
    size = len(part)
    if size == 0:
        return 0

    impurity = 1
    for count in unique_counts(part).values():
        impurity -= (count / size) ** 2
    return impurity


def _log2(value: float):
    """
    Return the base-2 logarithm of value.

    Delegates to math.log2 rather than computing log(value) / log(2),
    because the quotient of two rounded natural logarithms can miss the
    exact answer even for powers of two.
    """
    return log2(value)


def entropy(rows: Data):
    """
    t6: Entropy of the results in rows: minus the sum of
    p(x) * log2(p(x)) over every distinct result x, where p(x) is the
    fraction of rows carrying that result.
    """
    counts = unique_counts(rows)
    size = len(rows)

    score = 0
    for count in counts.values():
        prob = count / size
        score -= prob * _log2(prob)
    return score


def _split_numeric(prototype: List, column: int, value):
    """Tell whether the row's cell in the given column is at least value."""
    cell = prototype[column]
    return cell >= value


def _split_categorical(prototype: List, column: int, value):
    """Tell whether the row's cell in the given column equals value."""
    cell = prototype[column]
    return cell == value


def divideset(part: Data, column: int, value) -> Tuple[Data, Data]:
    """
    t7: Partition the rows on one column.

    When value is an int or a float, a row goes to the first set if its
    cell in that column is >= value; for any other value the row goes to
    the first set if its cell equals value. All remaining rows go to the
    second set. Row order is preserved within each set.
    """
    is_numeric = isinstance(value, (int, float))
    belongs = _split_numeric if is_numeric else _split_categorical

    matching, rest = [], []
    for row in part:
        if belongs(row, column, value):
            matching.append(row)
        else:
            rest.append(row)
    return matching, rest


class DecisionNode:
    def __init__(self, col=-1, value=None, results=None, tb=None, fb=None):
        """
        t8: A node of the decision tree, with 5 member variables:
        - col is the column index which represents the
          attribute we use to split the node
        - value corresponds to the answer that satisfies
          the question
        - tb and fb are the child nodes for the positive and
          negative answers, respectively
        - results is a dictionary that stores the result
          for this branch. Is None except for the leaves
        """
        self.col = col
        self.value = value
        self.results = results
        self.tb = tb
        self.fb = fb


def _best_split(part: Data, scoref):
    """
    Find the (column, value) split of part with the largest impurity
    decrease.

    Every distinct value of every column except the last (the result
    column) is tried with divideset. Splits that leave one side empty are
    ignored. Returns (best_gain, best_criteria, best_sets), where
    best_criteria is a (column, value) pair and best_sets is the pair of
    row lists; both are None when no split has a gain above 0.
    """
    current_score = scoref(part)
    total = len(part)

    best_gain = 0
    best_criteria = None
    best_sets = None

    # The last column holds the result, so it is never a split attribute.
    for col in range(len(part[0]) - 1):
        # dict.fromkeys keeps first-seen order, so ties between equally
        # good splits are broken the same way on every run.
        for value in dict.fromkeys(row[col] for row in part):
            set1, set2 = divideset(part, col, value)
            if not set1 or not set2:
                continue
            p = len(set1) / total
            gain = current_score - p * scoref(set1) - (1 - p) * scoref(set2)
            if gain > best_gain:
                best_gain = gain
                best_criteria = (col, value)
                best_sets = (set1, set2)

    return best_gain, best_criteria, best_sets


def buildtree(part: Data, scoref=entropy, beta=0):
    """
    t9: Recursively build a decision tree using any of the impurity
    measures we have seen.

    The stop criterion is that the best impurity decrease over all
    candidate splits is below beta (or that no split improves the score
    at all); the node then becomes a leaf holding unique_counts(part).

    :param part: rows to build the tree from (last column is the result)
    :param scoref: impurity function applied to a list of rows
    :param beta: minimum impurity decrease required to keep splitting
    :return: the root DecisionNode; an empty part yields DecisionNode()
    """
    if len(part) == 0:
        return DecisionNode()

    best_gain, best_criteria, best_sets = _best_split(part, scoref)
    if best_sets is None or best_gain < beta:
        return DecisionNode(results=unique_counts(part))

    true_branch = buildtree(best_sets[0], scoref, beta)
    false_branch = buildtree(best_sets[1], scoref, beta)
    return DecisionNode(col=best_criteria[0], value=best_criteria[1],
                        tb=true_branch, fb=false_branch)


def iterative_buildtree(part: Data, scoref=entropy, beta=0):
    """
    t10: Iterative version of buildtree.

    Applies the same split search and stop criterion as buildtree, but
    keeps the pending (rows, node) pairs on an explicit stack instead of
    recursing. Each node is created empty and filled in when it is
    popped.

    :param part: rows to build the tree from (last column is the result)
    :param scoref: impurity function applied to a list of rows
    :param beta: minimum impurity decrease required to keep splitting
    :return: the root DecisionNode; an empty part yields DecisionNode()
    """
    if len(part) == 0:
        return DecisionNode()

    root = DecisionNode()
    pending = [(part, root)]
    while pending:
        rows, node = pending.pop()
        best_gain, best_criteria, best_sets = _best_split(rows, scoref)
        if best_sets is None or best_gain < beta:
            node.results = unique_counts(rows)
            continue
        node.col, node.value = best_criteria
        node.tb = DecisionNode()
        node.fb = DecisionNode()
        pending.append((best_sets[0], node.tb))
        pending.append((best_sets[1], node.fb))
    return root


def print_tree(tree, headers=None, indent=""):
    """
    t11: Print the tree as indented text.

    A leaf (a node whose results is not None) prints its results. An
    internal node prints its question as "<criteria>: <value>?", then the
    true branch under "T->" and the false branch under "F->", each child
    indented two more spaces. Every line, leaves included, starts with
    the current indent so the nesting stays visible.

    :param tree: node exposing results, col, value, tb and fb
    :param headers: optional column names; when given, the criteria is
        shown as headers[col] instead of the column index
    :param indent: prefix for every line printed at this level
    """
    # Is this a leaf node?
    if tree.results is not None:
        print(f"{indent}{tree.results}")
        return

    # Print the criteria
    criteria = tree.col
    if headers:
        criteria = headers[criteria]
    print(f"{indent}{criteria}: {tree.value}?")

    # Print the branches
    print(f"{indent}T->")
    print_tree(tree.tb, headers, indent + "  ")
    print(f"{indent}F->")
    print_tree(tree.fb, headers, indent + "  ")


def main():
    """
    Demo entry point: load the data file named by the first command-line
    argument, print its header and rows, then split the data on column 1
    by several sample values and print the Gini impurity and entropy of
    both halves of each split.

    Exits with a usage message when no file argument is given.
    """
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} <data-file>")

    header, data = read(sys.argv[1])
    print(header)
    for row in data:
        print(row)

    for value in ["USA", "fRANCE", "UK", "NewZealand", ]:
        set1, set2 = divideset(data, 1, value)
        print("Split by: ", value)
        print(gini_impurity(set1))
        print(gini_impurity(set2))
        print(entropy(set1))
        print(entropy(set2))


if __name__ == "__main__":
    main()
Editor is loading...