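# Untitled
#
# Paste metadata captured from the pastecode.io page this script was copied from:
#   mail@pastecode.io avatar
#   unknown
#   python
#   a month ago
#   3.4 kB
#   2
#   Indexable
#   Never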
import matplotlib.pyplot as plt
import os

# Create the directory the best-epoch waveform plots are written to, if it doesn't exist.
os.makedirs('../best_plot/', exist_ok=True)

# Loss: mean squared error between the predicted and target waveforms.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# The scheduler monitors a validation *loss*, which should go down, so it must
# run in 'min' mode. In 'max' mode a decreasing loss is treated as getting
# worse, and the LR would be halved even while the model is improving.
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=100, factor=0.5, verbose=True)

# Function to plot and save waveforms
def plot_waveforms(input_waveform, target_waveform, output_waveform, epoch, sample_rate=4000):
    input_waveform = input_waveform.squeeze().numpy()
    target_waveform = target_waveform.squeeze().numpy()
    output_waveform = output_waveform.squeeze().detach().numpy()
    
    plt.figure(figsize=(15, 5))

    plt.subplot(3, 1, 1)
    plt.plot(input_waveform)
    plt.title('Input Waveform')
    
    plt.subplot(3, 1, 2)
    plt.plot(target_waveform)
    plt.title('Target Waveform')
    
    plt.subplot(3, 1, 3)
    plt.plot(output_waveform)
    plt.title('Output Waveform')
    
    plt.tight_layout()
    plt.savefig(f'../best_plot/waveform_comparison_epoch_{epoch}.png')
    plt.show()

# Training loop
num_epochs = 1000
time_steps = 36864  # samples per row after the .view(-1, time_steps) reshapes below
# NOTE: despite its name, this holds the best (lowest) validation MSE loss seen so far.
best_val_acc = float('inf')
# Training stops once the no-improvement counter EXCEEDS this value,
# i.e. after early_stop_threshold + 1 consecutive epochs without a new best.
early_stop_threshold = 300
early_stop_counter = 0

# torch.save does not create parent directories; make sure the checkpoint
# directory exists before the first save.
os.makedirs('../best_weight/', exist_ok=True)

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    running_loss = 0.0
    for inputs, targets, _ in tqdm(train_loader):
        # Move inputs and targets to the training device
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        # Reshape outputs and targets into rows of `time_steps` samples
        outputs = outputs.view(-1, time_steps)
        targets = targets.view(-1, time_steps)

        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Weight the per-batch mean loss by batch size so the division by the
        # dataset length below gives a per-sample average.
        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')

    # ---- Validation pass ----
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for inputs, targets, _ in test_loader:
            # Move inputs and targets to the evaluation device
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)

            # Flatten the tensors into rows of `time_steps` samples
            outputs = outputs.view(-1, time_steps)
            targets = targets.view(-1, time_steps)

            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)

    score = val_loss / len(test_loader.dataset)
    print(f'Epoch {epoch+1}/{num_epochs}, Val Loss: {score:.4f}')

    if score < best_val_acc:
        best_val_acc = score
        torch.save(model.state_dict(), '../best_weight/segemnter1mse.pth')
        print(f'--------------Validation Loss: {best_val_acc:.4f}--------------------')

        # Plot a random sample from the LAST validation batch: inputs/targets/outputs
        # still refer to that batch after the loop above.
        # NOTE(review): targets/outputs were reshaped to (-1, time_steps); indexing them
        # with a batch index assumes exactly one row per sample — confirm.
        random_idx = torch.randint(len(inputs), (1,)).item()
        plot_waveforms(inputs[random_idx].cpu(), targets[random_idx].cpu(), outputs[random_idx].cpu(), epoch)

        early_stop_counter = 0
    else:
        early_stop_counter += 1

    if early_stop_counter > early_stop_threshold:
        print("Early stopping at epoch:", epoch + 1)
        break

    # ReduceLROnPlateau should observe the current validation loss every epoch.
    # Because this is a loss (lower is better), the scheduler must be built with mode='min'.
    scheduler.step(score)
# (pastecode.io page footer) Leave a Comment