# import hyperspy.api as hs
# import atomap.api as am
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from tqdm import tqdm
from model.models import *
from dataloader import *
from utils import *
if torch.cuda.is_available():
    DEVICE = 'cuda:0'
    print('Running on the GPU')
else:
    DEVICE = "cpu"
    print('Running on the CPU')
MODEL_PATH = './trained/CrI3'
LOAD_MODEL = False
ROOT_DIR = './DATASET/Train/wave'
# IMG_HEIGHT = 110
# IMG_WIDTH = 220
BATCH_SIZE = 32
LEARNING_RATE = 0.002
EPOCHS = 10
def train_function(data, model, optimizer, loss_fn, device):
    """Run one training epoch and return the mean batch loss."""
    print('Entering into train function')
    loss_values = []
    data = tqdm(data)
    model.train()
    for index, batch in enumerate(data):
        X, y = batch
        # CrossEntropyLoss expects float logits of shape (N, C, H, W)
        # and long class targets of shape (N, H, W).
        X = X.to(device).float()
        y = y.to(device).long()
        preds = model(X)                      # (N, n_classes, H, W)
        loss = loss_fn(preds, y.squeeze(1))   # drop the mask channel dimension
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss_values.append(loss.item())
    return sum(loss_values) / len(loss_values)
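
# --- Optional: loading a saved checkpoint for inference ---
# A minimal sketch (not part of the original training loop) showing how the
# checkpoint written by main() could be reloaded. It reuses the UNet class,
# MODEL_PATH and DEVICE defined above; the example input size below is a
# hypothetical placeholder, not taken from the dataset.
def load_trained_model(model_path=MODEL_PATH, device=DEVICE):
    model = UNet(n_channels=1, n_classes=11).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model

# Example usage (hypothetical input size):
#   model = load_trained_model()
#   with torch.no_grad():
#       logits = model(torch.randn(1, 1, 400, 400, device=DEVICE))
#       mask = logits.argmax(dim=1)   # per-pixel class predictions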
def main():
    global epoch
    epoch = 0
    LOSS_VALS = []
    # Resize images/masks to a fixed size before batching.
    resize_transform = transforms.Compose([transforms.Resize((400, 400))])  # specify the size you want
    # get_wave_data (from dataloader) is expected to yield (image, mask) batches.
    train_set = get_wave_data(root_dir=ROOT_DIR,
                              batch_size=BATCH_SIZE,
                              transform=resize_transform)
    print('Data Loaded Successfully!')
    # Single-channel inputs, 11 segmentation classes.
    unet = UNet(n_channels=1, n_classes=11).to(DEVICE).train()
    optimizer = optim.Adam(unet.parameters(), lr=LEARNING_RATE)
    loss_function = nn.CrossEntropyLoss()
    # Optionally resume training from a saved checkpoint.
    if LOAD_MODEL:
        checkpoint = torch.load(MODEL_PATH)
        unet.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optim_state_dict'])
        epoch = checkpoint['epoch'] + 1
        LOSS_VALS = checkpoint['loss_values']
        print('Model successfully loaded!')
    for e in range(epoch, EPOCHS):
        print(f'Epoch: {e}')
        loss_val = train_function(train_set, unet, optimizer, loss_function, DEVICE)
        LOSS_VALS.append(loss_val)
        # Save a full checkpoint after every epoch so training can be resumed.
        torch.save({
            'model_state_dict': unet.state_dict(),
            'optim_state_dict': optimizer.state_dict(),
            'epoch': e,
            'loss_values': LOSS_VALS
        }, MODEL_PATH)
        print("Epoch completed and model successfully saved!")
if __name__ == '__main__':
    main()