# Convolutional variational autoencoder (VAE) trained on CIFAR-10 with TensorBoard logging.
import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import datasets, transforms from torchvision.utils import save_image, make_grid import os from torch.utils.tensorboard import SummaryWriter # 导入 TensorBoard # 定义卷积VAE模型 class VAE(nn.Module): def __init__(self, latent_dim=64, in_channels=3): super(VAE, self).__init__() modules = [] hidden_dims = [32, 64, 128, 256] # Build Encoder for h_dim in hidden_dims: modules.append( nn.Sequential( nn.Conv2d(in_channels, out_channels=h_dim, kernel_size= 3, stride= 2, padding = 1), nn.BatchNorm2d(h_dim), nn.LeakyReLU()) ) in_channels = h_dim self.encoder = nn.Sequential(*modules) self.fc_mu = nn.Linear(hidden_dims[-1]*2*2, latent_dim) self.fc_var = nn.Linear(hidden_dims[-1]*2*2, latent_dim) # Build Decoder modules = [] self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 2 * 2) hidden_dims.reverse() for i in range(len(hidden_dims) - 1): modules.append( nn.Sequential( nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride = 2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[i + 1]), nn.LeakyReLU()) ) self.decoder = nn.Sequential(*modules) self.final_layer = nn.Sequential( nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1, output_padding=1), nn.BatchNorm2d(hidden_dims[-1]), nn.LeakyReLU(), nn.Conv2d(hidden_dims[-1], out_channels= 3, kernel_size= 3, padding= 1), nn.Tanh()) def encode(self, input): result = self.encoder(input) result = torch.flatten(result, start_dim=1) # Split the result into mu and var components # of the latent Gaussian distribution mu = self.fc_mu(result) log_var = self.fc_var(result) return mu, log_var def reparameterize(self, mu, logvar): std = torch.exp(0.5*logvar) eps = torch.randn_like(std) return mu + eps*std def decode(self, z): result = self.decoder_input(z) result = result.view(-1, 256, 2, 2) result = self.decoder(result) result = self.final_layer(result) 
return result def forward(self, input): mu, log_var = self.encode(input) z = self.reparameterize(mu, log_var) # print(input.shape, mu.shape, log_var.shape, z.shape) return self.decode(z), mu, log_var # 损失函数 def loss_function(recon_x, x, mu, logvar, mse_weight=1, kld_weight=0.000015): MSE = mse_weight * torch.nn.functional.mse_loss(recon_x, x) # KL散度 KL = kld_weight * torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1), dim=0) return MSE + KL, KL, MSE if __name__ == '__main__': exp_name = 'vae_cifar10_1.5e-5' # 确保在Mac M芯片上使用MPS设备 device = torch.device("mps" if torch.backends.mps.is_available() else "cpu") if not os.path.exists(exp_name): os.makedirs(exp_name) if not os.path.exists(f'./runs/{exp_name}'): os.makedirs(f'./runs/{exp_name}') # 数据加载和预处理 transform = transforms.Compose([ # transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), ]) train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform) train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True) # 初始化TensorBoard writer = SummaryWriter(log_dir=f'./runs/{exp_name}') latent_dim = 128 epochs = 500 step = 0 # 模型训练 model = VAE(latent_dim).to(device) optimizer = optim.Adam(model.parameters(), lr=1e-3) for epoch in range(epochs): model.train() for batch_idx, (data, _) in enumerate(train_loader): data = data.to(device) optimizer.zero_grad() recon_batch, mu, logvar = model(data) kl_weight = 0.000015 loss, KL, BCE = loss_function(recon_batch, data, mu, logvar, kld_weight=kl_weight) loss.backward() optimizer.step() print(f"Epoch {epoch + 1}, Loss: {loss}, KL: {KL}, BCE: {BCE}") # 记录训练损失到TensorBoard writer.add_scalar('train/loss', loss, step) writer.add_scalar('train/KL', KL, step) writer.add_scalar('train/BCE', BCE, step) writer.add_scalar('train/kl_weight', kl_weight, step) step += 1 # 每个epoch后保存生成图像并记录到TensorBoard with torch.no_grad(): sample = torch.randn(64, latent_dim).to(device) sample = 
model.decode(sample).cpu() save_image(sample, f"{exp_name}/sample_epoch_{epoch+1}.png", nrow=8) img_grid = make_grid(sample) ndarr = img_grid.mul(255).add_(0.5).clamp_(0, 255).to("cpu", torch.uint8).numpy() # 将生成的图像添加到TensorBoard writer.add_image('Generated Images', ndarr, epoch) if (epoch+1) % 10 == 0: torch.save(model.state_dict(), f"{exp_name}/vae_epoch_{epoch+1}.pth") # 关闭TensorBoard writer writer.close()