Untitled

mail@pastecode.io avatar
unknown
python
2 years ago
4.5 kB
17
Indexable
Never
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets
import torchvision.transforms as transforms
from torch.utils.data import SequentialSampler


class LeNet(nn.Module):
    '''
    Small LeNet-style CNN for single-channel square images.

    Layout: conv(1->10, 5x5) -> max-pool 2 -> ReLU -> conv(10->20, 5x5) ->
    Dropout2d -> max-pool 2 -> ReLU -> flatten -> fc(feat_size->50) -> ReLU ->
    fc(50->num_classes) -> log-softmax.

    num_classes: number of output classes (width of the last layer).
    input_size: side length of the square input image. It fixes ``feat_size``,
        the flattened size fed to ``fc1`` (320 for 28, 500 for 32).

    Raises ValueError when ``input_size`` is too small to survive both
    conv/pool stages.
    '''

    def __init__(self, num_classes, input_size=28):
        super(LeNet, self).__init__()
        # Each stage is an unpadded 5x5 convolution (side - 4) followed by a
        # 2x2 max-pool with floor rounding (side // 2); conv2 emits 20 channels.
        side = int(((input_size - 4) // 2 - 4) // 2)
        if side < 1:
            raise ValueError(
                'input_size={} is too small for two 5x5 conv + 2x2 pool stages'.format(input_size)
            )
        self.feat_size = 20 * side * side
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(self.feat_size, 50)
        self.fc2 = nn.Linear(50, num_classes)

    def forward(self, x):
        ''' Return the log-softmax class scores, i.e. the last entry of forward_features(x). '''
        return self.forward_features(x)[-1]

    def forward_features(self, x):
        '''
        Run the network and return every intermediate activation.

        Returns [x1, x2, x3, x4]:
          x1: output of the first conv/pool/ReLU stage (10 channels),
          x2: output of the second stage flattened to rows of feat_size values,
          x3: ReLU output of fc1 (50 values per row),
          x4: log-softmax over the fc2 outputs (num_classes values per row).
        '''
        x1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        # Dropout2d is only active in training mode.
        x2 = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x1)), 2))
        x2 = x2.view(-1, self.feat_size)
        x3 = F.relu(self.fc1(x2))
        x4 = F.log_softmax(self.fc2(x3), dim=1)
        return [x1, x2, x3, x4]

    def forward_param_features(self, x):
        ''' Alias of forward_features. '''
        return self.forward_features(x)


def get_model():
    ''' Build the LeNet used by this script: 10 output classes, default input size. '''
    model = LeNet(num_classes=10)
    return model


def torch_get_function(network, loader):
    '''
    Collect function (features) from network.forward_features() over a loader.

    network: module exposing forward_features(inputs) -> list of tensors.
    loader: iterable of (inputs, targets) batches; targets are ignored.

    Returns a list with one float16 ndarray per feature returned by
    forward_features, each being the concatenation over all batches along
    axis 0. An empty loader yields an empty list.

    The network is switched to eval mode while features are collected (so
    dropout does not randomise the activations) and its previous
    training/eval mode is restored afterwards.
    '''
    was_training = network.training
    network.eval()
    per_batch = []
    try:
        # No gradients are needed for feature extraction.
        with torch.no_grad():
            for inputs, _targets in loader:
                inputs = inputs.to('cpu')
                per_batch.append([
                    f.detach().cpu().numpy().astype(np.float16)
                    for f in network.forward_features(inputs)
                ])
    finally:
        network.train(was_training)

    # Regroup from per-batch lists to per-feature tuples, then stitch batches together.
    return [np.concatenate(layer) for layer in zip(*per_batch)]


def adjacency(signals, metric=None):
    '''
    Build the n x n matrix A with a_{ij} = metric(a_i, a_j).

    signals: array whose first axis indexes the n signals; any trailing axes
        are flattened so that each signal becomes one row.
    metric: optional function f(., .) taking two 1D ndarrays and returning a
        single real number (e.g. correlation, KL divergence). When omitted
        (or falsy) the correlation coefficients of the rows are used.

    Returns the element-wise absolute value of the matrix with NaNs replaced
    via np.nan_to_num. With a custom metric the raw matrix is first rescaled
    by robust_scaler.
    '''
    flat = np.reshape(signals, (signals.shape[0], -1))

    if metric:
        count = flat.shape[0]
        pairwise = np.zeros((count, count))
        # Evaluate the metric on every ordered pair of rows.
        for row, left in enumerate(flat):
            for col, right in enumerate(flat):
                pairwise[row, col] = metric(left, np.transpose(right))
        # Normalize
        scores = robust_scaler(pairwise)
    else:
        # No metric provided: fast-compute correlation.
        scores = np.corrcoef(flat)

    return np.abs(np.nan_to_num(scores))


def robust_scaler(A, quantiles=(0.05, 0.95)):
    '''
    Rescale A so its lower quantile maps to 0 and its upper quantile to 1.

    A: array-like of numbers; quantiles are taken over all entries.
    quantiles: pair (lower, upper) of quantile levels in [0, 1].

    Returns (A - a) / (b - a) where a and b are the lower and upper quantile
    values. When both quantiles coincide (e.g. A is constant) the range is
    treated as 1, so the result is A - a instead of NaN/inf from a division
    by zero.
    '''
    A = np.asarray(A)
    low = np.quantile(A, quantiles[0])
    high = np.quantile(A, quantiles[1])
    spread = high - low
    if spread == 0:
        # Degenerate range: only shift, do not divide by zero.
        spread = 1.0
    return (A - low) / spread


def save_dipha(fname, adj):
    '''
    Write adjacency to binary. To use as DIPHA input for persistence homology.

    fname: path of the file to create (overwritten if it exists).
    adj: 2D array-like; adj.shape[0] is written as the matrix size.

    File layout: three int64 header values (8067171840, 7, number of rows)
    followed by the transposed matrix as float64 values.
    '''
    matrix = np.asarray(adj, dtype=np.double)
    # The context manager guarantees the file is flushed and closed, even on error.
    with open(fname, 'wb') as output_file:
        # NOTE(review): presumably DIPHA's magic number and the file-type id
        # for a distance matrix -- confirm against the DIPHA format docs.
        np.array(8067171840, dtype=np.int64).tofile(output_file)
        np.array(7, dtype=np.int64).tofile(output_file)
        np.array(matrix.shape[0], dtype=np.int64).tofile(output_file)
        matrix.T.tofile(output_file)


# Script: for each saved training epoch, load the LeNet checkpoint, collect the
# activations on the first 1000 MNIST test images, build the absolute
# correlation matrix between activation units and write 1 - correlation as a
# DIPHA input file.
SAVE_DIR = 'results_dnn_topology/lenet_mnist'

dataset = 'mnist'
net = get_model()
# Inference only: LeNet contains Dropout2d, which must be disabled so the
# collected activations are deterministic.
net.eval()
batch_size = 100

criterion = nn.CrossEntropyLoss()

epochs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

TRANSFORMS_MNIST = transforms.Compose([
    transforms.ToTensor(),
])

# The output directory must exist before save_dipha opens files inside it.
os.makedirs(SAVE_DIR, exist_ok=True)

# The evaluation data does not depend on the epoch, so build it once.
test_set = torchvision.datasets.MNIST('./data', download=True, train=False, transform=TRANSFORMS_MNIST)
test_set = torch.utils.data.Subset(test_set, list(range(0, 1000)))
first_sample = test_set[0]

functloader = torch.utils.data.DataLoader(
    test_set,
    batch_size=batch_size,
    sampler=SequentialSampler(test_set),
    num_workers=0,
    drop_last=True
)

for epoch in epochs:
    # Debug output: dataset size and the nesting sizes of its first sample.
    print(len(test_set), len(first_sample), len(first_sample[0]), len(first_sample[0][0]), len(first_sample[0][0][0]))

    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    torch_checkpoint = torch.load(
        './checkpoint/' + 'lenet' + '_' + 'mnist' + '/ckpt_trial_0_epoch_' + str(epoch) + '.t7',
        map_location='cpu'
    )
    net.load_state_dict(torch_checkpoint['net'])

    activs = torch_get_function(net, functloader)
    print(len(activs), len(activs[0]), len(activs[0][0]), len(activs[0][0][0]), len(activs[0][0][0][0]))
    # Flatten each feature to (samples, units), transpose to (units, samples)
    # and stack all units so every row is one unit's response over the samples.
    activs = np.concatenate([np.transpose(x.reshape(x.shape[0], -1)) for x in activs], axis=0)
    adj = adjacency(activs)

    # 1 - |correlation| is written, so strongly correlated units get small values.
    save_dipha(SAVE_DIR + '/adj_epc{}_trl0.bin'.format(epoch), 1-adj)