Untitled

mail@pastecode.io avatar
unknown
plain_text
20 days ago
2.9 kB
1
Indexable
Never

# import hyperspy.api as hs
# import atomap.api as am
import torch 
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from tqdm import tqdm

from model.models import *
from dataloader import *
from utils import *
from torch.autograd import Variable

# Use the first CUDA device when one is present, otherwise fall back to CPU.
if not torch.cuda.is_available():
    DEVICE = "cpu"
    print('Running on the CPU')
else:
    DEVICE = 'cuda:0'
    print('Running on the GPU')

# Checkpoint and dataset locations.
MODEL_PATH = './trained/CrI3'
LOAD_MODEL = False  # when True, main() resumes from the checkpoint at MODEL_PATH
ROOT_DIR = './DATASET/Train/wave'

# Optimisation hyper-parameters.
BATCH_SIZE = 32
LEARNING_RATE = 0.002
EPOCHS = 10

# Legacy tensor type matching the selected device; batches are cast with it.
Tensor = torch.Tensor if not torch.cuda.is_available() else torch.cuda.FloatTensor

def train_function(data, model, optimizer, loss_fn, device):
    """Run one training epoch over ``data`` and return the mean batch loss.

    Args:
        data: Iterable of ``(X, y)`` batches (e.g. a ``DataLoader``); it is
            wrapped in ``tqdm`` to show a progress bar.
        model: Network to train; it is switched to training mode here.
        optimizer: Optimizer that steps ``model``'s parameters.
        loss_fn: Callable ``loss_fn(preds, target)`` returning a scalar tensor.
        device: Device the batches are moved to; must be the device ``model``
            lives on.

    Returns:
        The mean of the per-batch loss values over the epoch, as a float.

    Raises:
        ValueError: If ``data`` yields no batches.
    """
    print('Entering into train function')
    loss_values = []
    model.train()
    for X, y in tqdm(data):
        # Move the batch to the model's device as float32. This replaces the
        # deprecated ``Variable(...).type(Tensor)`` pattern and honours the
        # ``device`` argument.
        X = X.to(device=device, dtype=torch.float32)
        y = y.to(device=device, dtype=torch.float32)
        preds = model(X)
        # Insert a singleton axis at dim 1, exactly as before.
        # NOTE(review): the model's output shape is not visible here; confirm
        # this layout is what loss_fn expects against ``y.squeeze()``.
        preds = preds[:, None, :, :].float()
        loss = loss_fn(preds, y.squeeze())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_values.append(loss.item())

    if not loss_values:
        # Without this check an empty loader would fail with an opaque
        # UnboundLocalError on ``loss``.
        raise ValueError('train_function received no batches from `data`')
    # Report the epoch-level mean rather than just the last (noisy) batch.
    return sum(loss_values) / len(loss_values)
def main():
    """Train the UNet on the wave dataset, checkpointing after every epoch.

    When ``LOAD_MODEL`` is true, the following are restored from
    ``MODEL_PATH`` before training resumes:

    - model and optimizer state;
    - the next epoch index;
    - the loss history.

    After each epoch the checkpoint at ``MODEL_PATH`` is overwritten with the
    current state and the accumulated per-epoch losses.
    """
    import os  # local import: only needed to create the checkpoint directory

    # Kept module-global for backward compatibility: the original script
    # exposed the starting epoch as the module attribute ``epoch``.
    global epoch
    epoch = 0
    loss_vals = []

    # Resize every image to 400x400 before it reaches the network.
    resize_transform = transforms.Compose([transforms.Resize((400, 400))])

    train_set = get_wave_data(root_dir=ROOT_DIR,
                              batch_size=BATCH_SIZE,
                              transform=resize_transform)
    print('Data Loaded Successfully!')

    unet = UNet(n_channels=1, n_classes=11).to(DEVICE).train()
    optimizer = optim.Adam(unet.parameters(), lr=LEARNING_RATE)
    loss_function = nn.CrossEntropyLoss()

    if LOAD_MODEL:
        # map_location lets a checkpoint written on a GPU be resumed on a
        # CPU-only machine (and vice versa).
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        unet.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
        epoch = checkpoint['epoch'] + 1  # resume after the last saved epoch
        loss_vals = checkpoint['loss_values']
        print('Model successfully loaded!')

    # torch.save does not create missing parent directories. Create them now
    # so that the first epoch's work is not lost on save.
    checkpoint_dir = os.path.dirname(MODEL_PATH)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    for e in range(epoch, EPOCHS):
        print(f'Epoch: {e}')
        loss_val = train_function(train_set, unet, optimizer, loss_function, DEVICE)
        loss_vals.append(loss_val)
        torch.save({
            'model_state_dict': unet.state_dict(),
            'optim_state_dict': optimizer.state_dict(),
            'epoch': e,
            'loss_values': loss_vals,
        }, MODEL_PATH)
        print("Epoch completed and model successfully saved!")


if __name__ == '__main__':
    main()