Untitled
unknown
plain_text
20 days ago
2.9 kB
1
Indexable
Never
"""Train a UNet segmentation model on the wave dataset.

Hyper-parameters and paths are module-level constants. When ``LOAD_MODEL``
is true, training resumes from the checkpoint stored at ``MODEL_PATH``; a
checkpoint (model, optimizer, last epoch, per-epoch losses) is written there
at the end of every epoch.
"""
import os

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from tqdm import tqdm

from model.models import *
from dataloader import *
from utils import *
from torch.autograd import Variable

if torch.cuda.is_available():
    DEVICE = 'cuda:0'
    print('Running on the GPU')
else:
    DEVICE = "cpu"
    print('Running on the CPU')

MODEL_PATH = './trained/CrI3'
LOAD_MODEL = False
ROOT_DIR = './DATASET/Train/wave'
IMG_SIZE = (400, 400)  # size passed to transforms.Resize for every sample
BATCH_SIZE = 32
LEARNING_RATE = 0.002
EPOCHS = 10
N_CLASSES = 11

# Legacy tensor-type alias kept for any external users; training below moves
# tensors with ``.to(device)`` instead.
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor


def train_function(data, model, optimizer, loss_fn, device):
    """Run one training epoch and return the mean batch loss.

    Args:
        data: iterable of ``(X, y)`` batches (e.g. a DataLoader).
        model: network to optimise; switched to train mode here.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: criterion called as ``loss_fn(preds, target)``.
        device: device the batches are moved to before the forward pass.

    Returns:
        The mean of the per-batch losses over the epoch, or ``float('nan')``
        if ``data`` yields no batches.
    """
    print('Entering into train function')
    model.train()
    loss_values = []
    progress = tqdm(data)
    for X, y in progress:
        # Cast to float32 on the requested device (what the old
        # ``Variable(...).type(Tensor)`` did, but honouring ``device``).
        X = X.to(device=device, dtype=torch.float32)
        y = y.to(device=device, dtype=torch.float32)

        preds = model(X)
        # NOTE(review): inserting a singleton axis at dim 1 and passing a float,
        # fully-squeezed target to CrossEntropyLoss is kept from the original.
        # CrossEntropyLoss normally expects (N, C, H, W) logits with (N, H, W)
        # integer class targets, and ``y.squeeze()`` also drops the batch axis
        # when a batch holds a single sample - confirm against the UNet output
        # and the dataset labels.
        preds = preds[:, None, :, :].float()
        loss = loss_fn(preds, y.squeeze())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_values.append(loss.item())
        progress.set_postfix(loss=loss_values[-1])

    if not loss_values:
        return float('nan')
    return sum(loss_values) / len(loss_values)


def main():
    """Build data, model and optimizer, optionally resume, then train.

    After every epoch the model/optimizer state, the epoch index and the list
    of per-epoch mean losses are saved to ``MODEL_PATH``.
    """
    start_epoch = 0
    loss_vals = []

    resize_transform = transforms.Compose([transforms.Resize(IMG_SIZE)])
    train_set = get_wave_data(root_dir=ROOT_DIR, batch_size=BATCH_SIZE,
                              transform=resize_transform)
    print('Data Loaded Successfully!')

    unet = UNet(n_channels=1, n_classes=N_CLASSES).to(DEVICE).train()
    optimizer = optim.Adam(unet.parameters(), lr=LEARNING_RATE)
    loss_function = nn.CrossEntropyLoss()

    if LOAD_MODEL:
        # map_location lets a checkpoint written on a GPU be resumed on a
        # CPU-only machine (and vice versa).
        checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
        unet.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        loss_vals = checkpoint['loss_values']
        print('Model successfully loaded!')

    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(MODEL_PATH) or '.', exist_ok=True)

    for e in range(start_epoch, EPOCHS):
        print(f'Epoch: {e}')
        loss_val = train_function(train_set, unet, optimizer, loss_function, DEVICE)
        loss_vals.append(loss_val)
        torch.save({
            'model_state_dict': unet.state_dict(),
            'optim_state_dict': optimizer.state_dict(),
            'epoch': e,
            'loss_values': loss_vals,
        }, MODEL_PATH)
        print("Epoch completed and model successfully saved!")


if __name__ == '__main__':
    main()