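# Paste-site metadata (pastecode.io), kept as comments so the file parses as Python:
# Untitled
#
# mail@pastecode.io avatar
# unknown
# plain_text
# 2 years ago
# 14 kB
# 2
# Indexable
# Never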
import math, shutil, os, time, argparse
import numpy as np
import scipy.io as sio
import torch.nn as nn
import torch
import torch.nn.parallel
import torch.optim
import torch.utils.data as DATA
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
from PIL import Image
from imutils import face_utils
import imutils
import dlib
import cv2
import torch.backends.cudnn as cudnn
from ITrackerModel import ITrackerModel
import statistics
import matplotlib
import random
import matplotlib.pyplot as plt



class AverageMeter(object):
    """Keep the latest value and a weighted running mean of a scalar metric.

    Attributes:
        val: the value passed to the most recent ``update``.
        sum: running total of ``val * n`` over all updates.
        count: running total of the weights ``n``.
        avg: ``sum / count`` as of the most recent ``update``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Set every statistic back to zero."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. a batch size) and refresh ``avg``.

        With plain Python numbers, a zero total weight (only ``n=0`` updates
        so far) raises ZeroDivisionError.
        """
        self.val = val
        weighted = val * n
        self.sum += weighted
        self.count += n
        self.avg = self.sum / self.count

# Directory checkpoints are read from (load_checkpoint) and written to (save_checkpoint).
CHECKPOINTS_PATH = '.'
epochs = 5
# 100 samples per visible GPU. Change if out of cuda memory.
# max(1, ...) keeps the batch size positive when no GPU is visible, because
# DataLoader rejects batch_size=0 with an unhelpful error.
batch_size = max(1, torch.cuda.device_count()) * 100
workers = 12
base_lr = 0.0001
momentum = 0.9
weight_decay = 1e-4
print_freq = 10  # NOTE(review): not referenced by train/validate, which print every batch.
prec1 = 0
best_prec1 = 1e20  # lowest validation error seen so far; 1e20 means "none yet"
lr = base_lr  # initial learning rate handed to the optimizer
count_test = 0
count = 0  # number of training batches processed (incremented in train)


def load_checkpoint(filename='tl_checkpoint.pth.tar'):
    """Load a checkpoint file stored under ``CHECKPOINTS_PATH``.

    The resolved path is printed first.

    Returns:
        Whatever ``torch.load`` deserializes from the file, or ``None`` when
        no regular file exists at that path.
    """
    path = os.path.join(CHECKPOINTS_PATH, filename)
    print(path)
    if os.path.isfile(path):
        return torch.load(path)
    return None

def save_checkpoint(state, is_best, filename='1019checkpoint.pth.tar'):
    """Serialize ``state`` into ``CHECKPOINTS_PATH/filename``.

    The checkpoint directory is created (mode 0o777) if missing. When
    ``is_best`` is true, the written file is also copied to
    ``CHECKPOINTS_PATH/best_<filename>``.
    """
    if not os.path.isdir(CHECKPOINTS_PATH):
        os.makedirs(CHECKPOINTS_PATH, 0o777)
    target = os.path.join(CHECKPOINTS_PATH, filename)
    torch.save(state, target)
    if not is_best:
        return
    best_target = os.path.join(CHECKPOINTS_PATH, 'best_' + filename)
    shutil.copyfile(target, best_target)

#
# def transfer_model(model):
#     a=0
#     for i in model.parameters():
#         if i.requires_grad:
#             # print(i)
#             # print(i.size())
#             i.requires_grad = False
#             if a > 17 and a != 20 and a !=21:
#                 i.requires_grad = True
#                 # print(a)
#             a=a+1
#     return model

def train(train_loader, model, criterion, optimizer, epoch):
    """Run one optimizer pass over ``train_loader`` and return the mean gaze error.

    Each batch is moved to the GPU, passed through ``model`` and used for one
    optimizer step on ``criterion(output, gaze)``. Besides the loss, the
    per-sample Euclidean distance between ``output`` and ``gaze`` is tracked
    (batch mean and standard deviation), and a progress line is printed for
    every batch. Increments the module-level ``count`` once per batch.

    Args:
        train_loader: iterable yielding ``(imFace, imEyeL, imEyeR, faceGrid, gaze)``.
        model: network called as ``model(imFace, imEyeL, imEyeR, faceGrid)``.
        criterion: loss applied to ``(output, gaze)``.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the printed progress line.

    Returns:
        Sample-weighted average over batches of the mean Euclidean error.
    """
    global count
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    Eim = AverageMeter()
    Eis = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()

    for i, (imFace, imEyeL, imEyeR, faceGrid, gaze) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)
        # Only the model parameters are optimized, so the inputs need no
        # gradient tracking (requiring it would just waste memory/compute).
        imFace = imFace.cuda()
        imEyeL = imEyeL.cuda()
        imEyeR = imEyeR.cuda()
        faceGrid = faceGrid.cuda()
        gaze = gaze.cuda()
        n = imFace.size(0)

        # compute output
        output = model(imFace, imEyeL, imEyeR, faceGrid)
        loss = criterion(output, gaze)

        # Error statistics are diagnostics only; keep them out of the graph.
        with torch.no_grad():
            diff = output - gaze
            # Per-axis sum of absolute errors (|d| == sqrt(d * d)).
            per_axis = torch.sum(diff.abs(), 0)
            print("sqrt Eim_x= ", per_axis[0])
            print("sqrt Eim_y= ", per_axis[1])
            # Ei = sqrt((x - xi)^2 + (y - yi)^2) per sample
            Ei = torch.sqrt(torch.sum(diff * diff, 1))
            Eimu = torch.mean(Ei)
            # The unbiased std of a single sample is NaN, which would poison
            # the running average for the rest of the epoch; report 0 instead.
            Eisigma = torch.std(Ei) if Ei.numel() > 1 else torch.zeros_like(Eimu)

        losses.update(loss.item(), n)
        Eim.update(Eimu.item(), n)
        Eis.update(Eisigma.item(), n)

        # compute gradient and do SGD step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        count = count + 1

        print('Epoch (train): [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t''Mu and S {Eimu.val:.4f},{Eisigma.val:.4f} ({Eimu.avg:.4f},{Eisigma.avg:.4f})\t'.format(
            epoch, i, len(train_loader), batch_time=batch_time, data_time=data_time, loss=losses, Eimu=Eim, Eisigma=Eis))

    return Eim.avg


def validate(val_loader, model, criterion, epoch):
    """Evaluate ``model`` on ``val_loader`` without updating it.

    For every batch: compute ``criterion(output, gaze)`` and the per-sample
    Euclidean distance between ``output`` and ``gaze`` (batch mean and
    standard deviation), then print a progress line. After the loop, the last
    batch's ``output`` and ``gaze`` are written (as their ``str`` form) to
    ``outputlog.txt`` and ``gazelog.txt``. The original rewrote both files in
    'w' mode on every batch, so only the last batch survived anyway.

    Args:
        val_loader: iterable yielding ``(imFace, imEyeL, imEyeR, faceGrid, gaze)``.
        model: network called as ``model(imFace, imEyeL, imEyeR, faceGrid)``.
        criterion: loss applied to ``(output, gaze)``.
        epoch: epoch number, used only in the printed progress line.

    Returns:
        Sample-weighted average over batches of the mean Euclidean error
        (0 for an empty loader).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    lossesLin = AverageMeter()
    Eis = AverageMeter()
    Eim = AverageMeter()
    # switch to evaluate mode
    model.eval()
    end = time.time()
    output = None

    with torch.no_grad():
        for i, (imFace, imEyeL, imEyeR, faceGrid, gaze) in enumerate(val_loader):
            # measure data loading time
            data_time.update(time.time() - end)
            imFace = imFace.cuda()
            imEyeL = imEyeL.cuda()
            imEyeR = imEyeR.cuda()
            faceGrid = faceGrid.cuda()
            gaze = gaze.cuda()
            n = imFace.size(0)

            # compute output
            output = model(imFace, imEyeL, imEyeR, faceGrid)
            loss = criterion(output, gaze)

            # Ei = sqrt((x - xi)^2 + (y - yi)^2) per sample
            diff = output - gaze
            Ei = torch.sqrt(torch.sum(diff * diff, 1))
            Eimu = torch.mean(Ei)
            # The unbiased std of a single sample is NaN; report 0 instead so
            # the running average stays finite.
            Eisigma = torch.std(Ei) if Ei.numel() > 1 else torch.zeros_like(Eimu)

            losses.update(loss.item(), n)
            # The "linear" L2 error is the mean Euclidean distance, i.e. Eimu.
            lossesLin.update(Eimu.item(), n)
            Eim.update(Eimu.item(), n)
            Eis.update(Eisigma.item(), n)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            print('Epoch (val): [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Error L2 {lossLin.val:.4f} ({lossLin.avg:.4f})\t''Mu and S {Eimu.val:.4f},{Eisigma.val:.4f} ({Eimu.avg:.4f},{Eisigma.avg:.4f})\t'.format(
                        epoch, i, len(val_loader), batch_time=batch_time,
                       loss=losses, lossLin=lossesLin, Eimu=Eim, Eisigma=Eis))

    # Dump the last batch once instead of rewriting the files every batch.
    if output is not None:
        with open('outputlog.txt', 'w') as f:
            f.write(str(output))
        with open('gazelog.txt', 'w') as f:
            f.write(str(gaze))

    return lossesLin.avg

def adjust_learning_rate(optimizer, epoch, initial_lr=None, decay=0.1, step=30):
    """Set every param group's LR to ``initial_lr * decay ** (epoch // step)``.

    With the defaults this is the module-level ``base_lr`` decayed by 10x
    every 30 epochs.

    Args:
        optimizer: a ``torch.optim`` optimizer whose param groups are updated.
        epoch: current (0-based) epoch number.
        initial_lr: starting learning rate; ``None`` means use ``base_lr``.
        decay: multiplicative factor applied once per ``step`` epochs.
        step: number of epochs between decays.

    Returns:
        The learning rate that was applied.
    """
    if initial_lr is None:
        initial_lr = base_lr
    lr = initial_lr * (decay ** (epoch // step))
    # Write to the live param groups: optimizer.state_dict() returns copies of
    # the group dicts, so editing those never reaches the optimizer.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr


def _load_chunk(directory, prefixes, index):
    """Return ``torch.load(directory/<prefix><index>.pt)`` for each prefix, in order."""
    return [torch.load(os.path.join(directory, prefix + str(index) + '.pt'))
            for prefix in prefixes]


def _plot_learning_curves(train_errors, test_errors):
    """Plot per-epoch train (red) and test (blue) distance error and show the figure."""
    # figsize is in inches; the previous (150, 100) at dpi=100 meant a
    # 15000x10000-pixel figure.
    plt.figure(figsize=(15, 10), dpi=100, linewidth=2)
    epoch_numbers = list(range(len(train_errors)))
    plt.plot(epoch_numbers, train_errors, color='r', label="train")
    plt.plot(epoch_numbers, test_errors, color='b', label="test")
    plt.legend(loc="best", fontsize=20)

    plt.xlabel("epoch(s)", fontsize=30, labelpad=15)
    # Label the y axis (labelpad is the distance between the label and the plot).
    plt.ylabel("distance error", fontsize=30, labelpad=20)
    plt.show()


def main():
    """Train ITrackerModel on chunked tensor files, validating after every epoch.

    Data layout: ``photo_train_pt/`` and ``photo_test_pt/`` hold five ``.pt``
    files per chunk, so the chunk count is ``number_of_files // 5``. Every
    epoch visits all training chunks, then all test chunks, each in random
    order. After each epoch the mean test error is compared with
    ``best_prec1``; a checkpoint is saved (also as "best" when improved), and
    training stops early after more than 10 consecutive non-improving epochs.
    Finally, the train/test error curves are plotted.

    Raises:
        RuntimeError: if either data directory holds fewer than five files,
            since no chunk could be loaded and averaging would fail later.
    """
    global best_prec1

    train_dir = 'photo_train_pt/'
    test_dir = 'photo_test_pt/'
    # NOTE(review): tensors are passed positionally as
    # (imFace, imEyeL, imEyeR, faceGrid, gaze) into train/validate; confirm the
    # gl_/gr_/gf_ files really hold face/left-eye/right-eye in that order.
    train_prefixes = ('gl_', 'gr_', 'gf_', 'gfg_', 'gtr_')
    test_prefixes = ('tgl_', 'tgr_', 'tgf_', 'tgfg_', 'tgtr_')

    # Each chunk is stored as five files.
    train_chunks = len(os.listdir(train_dir)) // 5
    test_chunks = len(os.listdir(test_dir)) // 5
    if train_chunks == 0 or test_chunks == 0:
        raise RuntimeError('need at least one 5-file chunk in both %s (%d found) and %s (%d found)'
                           % (train_dir, train_chunks, test_dir, test_chunks))

    model = torch.nn.DataParallel(ITrackerModel())
    model.cuda()
    cudnn.benchmark = True

    saved = load_checkpoint()
    if saved:
        print(
            'Loading checkpoint for epoch %05d with loss %.5f (which is the mean squared error not the actual linear error)...' % (
            saved['epoch'], saved['best_prec1']))
        state = saved['state_dict']
        try:
            model.module.load_state_dict(state)
        except RuntimeError:
            # Key mismatch, e.g. a state dict saved from the DataParallel
            # wrapper ('module.' prefix): load it through the wrapper instead.
            model.load_state_dict(state)
        best_prec1 = saved['best_prec1']
    else:
        print('Warning: Could not read checkpoint!')

    # NOTE(review): training always starts at epoch 0; the checkpoint's
    # 'epoch' is only printed, never resumed. Confirm this is intended.
    start_epoch = 0
    criterion = nn.MSELoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr,
                                momentum=momentum,
                                weight_decay=weight_decay)

    plot_train_list = []
    plot_test_list = []
    early_stop_trigger = 0
    for epoch in range(start_epoch, epochs):
        adjust_learning_rate(optimizer, epoch)
        # Visit every chunk exactly once per epoch, in random order.
        random_list = random.sample(range(train_chunks), train_chunks)
        test_random_list = random.sample(range(test_chunks), test_chunks)

        train_average = []
        for random_index in random_list:
            print('load input', random_index)
            tensors = _load_chunk(train_dir, train_prefixes, random_index)
            print(*(len(t) for t in tensors))
            dataTrain = DATA.TensorDataset(*tensors)
            del tensors
            print("loading train input complete")

            train_loader = torch.utils.data.DataLoader(
                dataTrain,
                batch_size=batch_size, shuffle=True,
                num_workers=workers, pin_memory=True)

            # train on this chunk
            train_average.append(train(train_loader, model, criterion, optimizer, epoch))
            del train_loader

        test_average = []
        for test_index in test_random_list:
            print('loading test input', str(test_index))
            dataVal = DATA.TensorDataset(*_load_chunk(test_dir, test_prefixes, test_index))

            val_loader = torch.utils.data.DataLoader(
                dataVal,
                batch_size=batch_size, shuffle=False,
                num_workers=workers, pin_memory=True)

            print("loading test input complete")
            # evaluate on this validation chunk
            test_average.append(validate(val_loader, model, criterion, epoch))

        test_mean = statistics.mean(test_average)
        train_mean = statistics.mean(train_average)
        plot_test_list.append(test_mean)
        plot_train_list.append(train_mean)

        # remember best (lowest) error and save checkpoint
        is_best = test_mean < best_prec1
        best_prec1 = min(test_mean, best_prec1)
        if is_best:
            early_stop_trigger = 0
        else:
            early_stop_trigger += 1
            if early_stop_trigger > 10:
                break
        save_checkpoint({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
        }, is_best)

    _plot_learning_curves(plot_train_list, plot_test_list)



if __name__ == "__main__":
    # Script entry point: run the full training/validation loop, then report completion.
    main()
    print('DONE')