Untitled
unknown
python
2 years ago
3.4 kB
14
Indexable
import matplotlib.pyplot as plt
import os
# Create the directory the per-epoch waveform plots are saved into, if it doesn't exist
os.makedirs('../best_plot/', exist_ok=True)
# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# The training loop steps this scheduler with the best validation *loss* so far (lower is
# better), so it must run in mode='min'. With mode='max', that never-increasing series would
# never count as an improvement, and the LR would be halved on a fixed schedule regardless
# of progress.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases; drop it if it warns/errors.
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=100, factor=0.5, verbose=True)
# Function to plot and save waveforms
def plot_waveforms(input_waveform, target_waveform, output_waveform, epoch, sample_rate=4000,
                   save_dir='../best_plot/'):
    """Plot input, target and output waveforms in three stacked panels and save the figure.

    Args:
        input_waveform: tensor with the model input for one sample (squeezed before plotting).
        target_waveform: tensor with the ground-truth waveform for the same sample.
        output_waveform: tensor with the model prediction for the same sample; it may still
            track gradients or live on an accelerator — it is detached and moved to CPU.
        epoch: epoch index, used only to build the output file name.
        sample_rate: currently unused by this function (the x axis is the sample index);
            kept for backward compatibility.
        save_dir: directory the PNG is written into; created if it does not exist.

    The figure is written to ``<save_dir>/waveform_comparison_epoch_{epoch}.png``, shown
    with ``plt.show()``, and then closed so repeated calls do not accumulate open figures.
    """
    panels = (
        ('Input Waveform', input_waveform),
        ('Target Waveform', target_waveform),
        ('Output Waveform', output_waveform),
    )
    fig, axes = plt.subplots(len(panels), 1, figsize=(15, 5))
    for ax, (title, waveform) in zip(axes, panels):
        # detach()/cpu() make .numpy() safe for grad-tracking or GPU tensors
        ax.plot(waveform.detach().cpu().squeeze().numpy())
        ax.set_title(title)
    fig.tight_layout()
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, f'waveform_comparison_epoch_{epoch}.png'))
    plt.show()
    # Release the figure; without this every call leaks one figure's memory.
    plt.close(fig)
# Training loop
num_epochs = 1000
# Length of each row that outputs/targets are reshaped into before computing the loss.
# NOTE(review): presumably samples per waveform; must divide the tensors' element count.
time_steps = 36864
# Despite the "acc" in its name, this holds the lowest average validation loss seen so far
# (lower is better), hence the +inf starting value.
best_val_acc = float('inf')
early_stop_threshold = 300
early_stop_counter = 0
best_weight_path = '../best_weight/segemnter1mse.pth'
# torch.save does not create missing parent directories, so create the checkpoint folder
# up front (the plot folder is created the same way).
os.makedirs(os.path.dirname(best_weight_path), exist_ok=True)

for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    running_loss = 0.0
    for inputs, targets, _ in tqdm(train_loader):
        # Move inputs and targets to the training device
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        # Reshape outputs and targets into rows of `time_steps` samples; reshape (unlike view)
        # also works if the model returns a non-contiguous tensor.
        outputs = outputs.reshape(-1, time_steps)
        targets = targets.reshape(-1, time_steps)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        # Scale by batch size so dividing by the dataset size below yields a per-sample
        # average (assumes the criterion averages over the batch).
        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

    # ---- validation pass ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, targets, _ in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            # Flatten the tensors exactly as in training
            outputs = outputs.reshape(-1, time_steps)
            targets = targets.reshape(-1, time_steps)
            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)
    score = val_loss / len(test_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Val Loss: {score:.4f}')

    if score < best_val_acc:
        # New best validation loss: checkpoint the weights and plot one example.
        best_val_acc = score
        torch.save(model.state_dict(), best_weight_path)
        print(f'--------------Validation Loss: {best_val_acc:.4f}--------------------')
        # Plot and save the waveforms for a random sample of the last validation batch.
        # NOTE(review): the index is drawn over the batch (len(inputs)) but also applied to
        # the reshaped outputs/targets; they line up only if each sample reshapes to exactly
        # one row of `time_steps` — confirm against the model's output shape.
        random_idx = torch.randint(len(inputs), (1,)).item()
        plot_waveforms(inputs[random_idx].cpu(), targets[random_idx].cpu(), outputs[random_idx].cpu(), epoch)
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        # Stop once more than `early_stop_threshold` consecutive epochs fail to improve.
        if early_stop_counter > early_stop_threshold:
            print("Early stopping at epoch:", epoch + 1)
            break

    # The running best loss (never increasing) is passed rather than `score`.
    # NOTE(review): this acts as plateau detection only if the scheduler was built with
    # mode='min'; with mode='max' this series never registers an improvement.
    scheduler.step(best_val_acc)
Editor is loading...
Leave a Comment