Untitled
unknown
python
a year ago
3.4 kB
4
Indexable
import os

import matplotlib.pyplot as plt

# NOTE(review): `torch`, `nn`, `ReduceLROnPlateau`, `tqdm`, `model`, `device`,
# `train_loader` and `test_loader` are not defined in this snippet; they are
# assumed to be provided earlier in the file/notebook.

# Create the output directories if they don't exist. The checkpoint directory
# is needed too: torch.save() does not create missing parent directories.
os.makedirs('../best_plot/', exist_ok=True)
os.makedirs('../best_weight/', exist_ok=True)

# Define loss function, optimizer and learning-rate scheduler.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# The monitored metric is the validation MSE (lower is better), so the scheduler
# runs in 'min' mode and is stepped with the current epoch's validation loss.
# Stepping a 'max'-mode scheduler with the (never-increasing) best loss would
# count every epoch as a plateau and decay the LR on a fixed schedule.
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=100, factor=0.5, verbose=True)


def plot_waveforms(input_waveform, target_waveform, output_waveform, epoch, sample_rate=4000):
    """Plot input, target and model-output waveforms and save the figure.

    The three signals are drawn in three stacked subplots. The figure is written
    to ``../best_plot/waveform_comparison_epoch_{epoch}.png``, shown, and then
    closed.

    Args:
        input_waveform: Tensor with the model input for one sample.
        target_waveform: Tensor with the expected output for that sample.
        output_waveform: Tensor with the model prediction for that sample.
        epoch: Epoch number embedded in the saved file name.
        sample_rate: Currently unused; the x-axis is the sample index.
    """
    # detach()/cpu() make the NumPy conversion safe for tensors that still
    # require grad or live on the GPU; both are no-ops for plain CPU tensors.
    input_waveform = input_waveform.squeeze().detach().cpu().numpy()
    target_waveform = target_waveform.squeeze().detach().cpu().numpy()
    output_waveform = output_waveform.squeeze().detach().cpu().numpy()

    plt.figure(figsize=(15, 5))

    plt.subplot(3, 1, 1)
    plt.plot(input_waveform)
    plt.title('Input Waveform')

    plt.subplot(3, 1, 2)
    plt.plot(target_waveform)
    plt.title('Target Waveform')

    plt.subplot(3, 1, 3)
    plt.plot(output_waveform)
    plt.title('Output Waveform')

    plt.tight_layout()
    plt.savefig(f'../best_plot/waveform_comparison_epoch_{epoch}.png')
    plt.show()
    # Release the figure; pyplot keeps figures alive until closed, so without
    # this a new figure would pile up on every improvement.
    plt.close()


# Training loop
num_epochs = 1000
# Number of values per sample; outputs and targets are flattened to
# (-1, time_steps) before computing the loss.
# NOTE(review): assumes every sample holds exactly time_steps values — TODO confirm.
time_steps = 36864
# Lowest validation loss seen so far. Despite the name this is a loss, so lower
# is better; the name is kept in case other cells reference it.
best_val_acc = float('inf')
# Stop once the validation loss has failed to improve for more than this many
# consecutive epochs.
early_stop_threshold = 300
early_stop_counter = 0

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    running_loss = 0.0
    for inputs, targets, _ in tqdm(train_loader):
        # Move inputs and targets to GPU
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        # Reshape outputs and targets; reshape() (unlike view()) also accepts
        # non-contiguous tensors.
        outputs = outputs.reshape(-1, time_steps)
        targets = targets.reshape(-1, time_steps)

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # nn.MSELoss() returns a mean, so weight it by the batch size to get a
        # per-sample average over the whole dataset below.
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

    # ---- Validation pass ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, targets, _ in test_loader:
            # Move inputs and targets to GPU
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)

            # Flatten the tensors
            outputs = outputs.reshape(-1, time_steps)
            targets = targets.reshape(-1, time_steps)

            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)

    score = val_loss / len(test_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Val Loss: {score:.4f}')

    if score < best_val_acc:
        best_val_acc = score
        torch.save(model.state_dict(), '../best_weight/segemnter1mse.pth')
        print(f'--------------Validation Loss: {best_val_acc:.4f}--------------------')

        # Plot and save the waveforms for a random sample. After the loop above,
        # inputs/targets/outputs still hold the last validation batch.
        random_idx = torch.randint(len(inputs), (1,)).item()
        plot_waveforms(inputs[random_idx].cpu(), targets[random_idx].cpu(), outputs[random_idx].cpu(), epoch)
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter > early_stop_threshold:
            print("Early stopping at epoch:", epoch + 1)
            break

    # Step on this epoch's validation loss (not the running best) so the
    # scheduler can actually detect a plateau.
    scheduler.step(score)
Editor is loading...
Leave a Comment