# Paste-site metadata: "Untitled", author unknown, python, posted a year ago, 6.2 kB, 13, Indexable
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import os
from torch.utils.tensorboard import SummaryWriter # 导入 TensorBoard
# Convolutional VAE model definition
class VAE(nn.Module):
    """Convolutional variational autoencoder built for 32x32 images.

    The encoder is four stride-2 convolutions (32 -> 64 -> 128 -> 256
    channels), each halving the spatial size, so a 32x32 input reaches a
    2x2 feature map that is flattened and projected to the mean and
    log-variance of a diagonal Gaussian.  The decoder mirrors this with
    four stride-2 transposed convolutions back to 32x32 and finishes with
    a ``Tanh``, so reconstructions lie in ``[-1, 1]``.

    Args:
        latent_dim: Size of the latent vector ``z``.
        in_channels: Number of channels of the input images; the
            reconstruction has the same number of channels.
    """

    # Side length of the square feature map at the bottleneck.  The linear
    # layers are sized for it, which is what ties the model to 32x32 inputs.
    _BOTTLENECK_SIZE = 2

    def __init__(self, latent_dim=64, in_channels=3):
        super().__init__()

        hidden_dims = [32, 64, 128, 256]
        # ``in_channels`` is overwritten while the encoder is built, so keep
        # the image channel count for the output convolution.
        image_channels = in_channels
        # Channel count at the bottleneck; ``decode`` needs it to un-flatten.
        self.bottleneck_channels = hidden_dims[-1]
        flat_features = (
            self.bottleneck_channels
            * self._BOTTLENECK_SIZE
            * self._BOTTLENECK_SIZE
        )

        # Build encoder: each stage halves the spatial resolution.
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim
        self.encoder = nn.Sequential(*modules)

        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

        # Build decoder: mirror of the encoder, doubling the resolution.
        self.decoder_input = nn.Linear(latent_dim, flat_features)

        hidden_dims.reverse()
        modules = []
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )
        self.decoder = nn.Sequential(*modules)

        # Last upsampling step followed by a projection to image channels.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=image_channels,
                      kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def encode(self, input):
        """Map a batch of images to the parameters of the latent Gaussian.

        Returns:
            Tuple ``(mu, log_var)``, each of shape ``(batch, latent_dim)``.
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)
        # Two independent linear heads produce the mean and log-variance.
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)
        return mu, log_var

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling this way keeps ``z`` differentiable w.r.t. ``mu`` and
        ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors of shape ``(batch, latent_dim)`` to images."""
        result = self.decoder_input(z)
        # Un-flatten to the bottleneck feature map expected by the decoder.
        result = result.view(-1, self.bottleneck_channels,
                             self._BOTTLENECK_SIZE, self._BOTTLENECK_SIZE)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def forward(self, input):
        """Encode, sample and decode.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var
# Loss function
def loss_function(recon_x, x, mu, logvar, mse_weight=1, kld_weight=0.000015):
    """Compute the weighted VAE objective.

    Args:
        recon_x: Reconstructed batch.
        x: Target batch, same shape as ``recon_x``.
        mu: Latent means, shape ``(batch, latent_dim)``.
        logvar: Latent log-variances, same shape as ``mu``.
        mse_weight: Multiplier applied to the reconstruction term.
        kld_weight: Multiplier applied to the KL term.

    Returns:
        Tuple ``(total, kl, mse)`` where ``kl`` and ``mse`` are the already
        weighted terms and ``total`` is their sum.
    """
    # Reconstruction term: mean squared error averaged over every element.
    mse_term = mse_weight * torch.nn.functional.mse_loss(recon_x, x)

    # KL divergence between N(mu, exp(logvar)) and N(0, I): summed over the
    # latent dimensions of each sample, then averaged over the batch.
    kl_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)
    kl_term = kld_weight * torch.mean(kl_per_sample, dim=0)

    return mse_term + kl_term, kl_term, mse_term
if __name__ == '__main__':
    # Train the VAE on CIFAR-10, logging losses and samples to TensorBoard.
    exp_name = 'vae_cifar10_1.5e-5'
    log_dir = f'./runs/{exp_name}'

    # Use the Apple-Silicon MPS backend when available, otherwise the CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    # exist_ok avoids the check-then-create race of os.path.exists().
    os.makedirs(exp_name, exist_ok=True)
    os.makedirs(log_dir, exist_ok=True)

    # Data loading and preprocessing.
    # NOTE(review): ToTensor alone yields inputs in [0, 1] while the decoder
    # ends in Tanh (range [-1, 1]) - confirm whether normalising the data to
    # [-1, 1] was intended.
    transform = transforms.Compose([
        # transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])
    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

    latent_dim = 128
    epochs = 500
    kl_weight = 0.000015  # constant for the whole run, so set once
    step = 0  # global batch counter used as the TensorBoard x-axis

    model = VAE(latent_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # The context manager closes (and flushes) the writer even when training
    # is interrupted or raises.
    with SummaryWriter(log_dir=log_dir) as writer:
        for epoch in range(epochs):
            model.train()
            for data, _ in train_loader:
                data = data.to(device)
                optimizer.zero_grad()
                recon_batch, mu, logvar = model(data)
                # The third value is the weighted MSE term; the "BCE" label
                # and tag below are kept unchanged so existing logs line up.
                loss, KL, BCE = loss_function(recon_batch, data, mu, logvar, kld_weight=kl_weight)
                loss.backward()
                optimizer.step()

                # Convert to Python floats so no autograd graph or device
                # tensor ends up in the console output or the event file.
                loss_value = loss.item()
                kl_value = KL.item()
                recon_value = BCE.item()
                print(f"Epoch {epoch + 1}, Loss: {loss_value}, KL: {kl_value}, BCE: {recon_value}")

                # Record the training losses in TensorBoard.
                writer.add_scalar('train/loss', loss_value, step)
                writer.add_scalar('train/KL', kl_value, step)
                writer.add_scalar('train/BCE', recon_value, step)
                writer.add_scalar('train/kl_weight', kl_weight, step)
                step += 1

            # After every epoch, save generated images and log them.
            # eval() makes BatchNorm use its running statistics and stops the
            # random latents from updating them; train() is re-enabled at the
            # top of the next epoch.
            model.eval()
            with torch.no_grad():
                sample = torch.randn(64, latent_dim).to(device)
                sample = model.decode(sample).cpu()
            save_image(sample, f"{exp_name}/sample_epoch_{epoch+1}.png", nrow=8)

            # Same [0, 1] -> uint8 conversion that save_image applies.
            img_grid = make_grid(sample)
            ndarr = img_grid.mul(255).add_(0.5).clamp_(0, 255).to("cpu", torch.uint8).numpy()
            writer.add_image('Generated Images', ndarr, epoch)

            # Checkpoint every 10 epochs.
            if (epoch + 1) % 10 == 0:
                torch.save(model.state_dict(), f"{exp_name}/vae_epoch_{epoch+1}.pth")
# Leave a Comment