import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import os
from torch.utils.tensorboard import SummaryWriter  # TensorBoard logging


# Convolutional VAE model
class VAE(nn.Module):
    def __init__(self, latent_dim=64, in_channels=3):
        super(VAE, self).__init__()
        modules = []
        hidden_dims = [32, 64, 128, 256]

        # Build Encoder
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(hidden_dims[-1]*2*2, latent_dim)
        self.fc_var = nn.Linear(hidden_dims[-1]*2*2, latent_dim)


        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, hidden_dims[-1] * 2 * 2)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(hidden_dims[i],
                                       hidden_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU())
            )
        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(hidden_dims[-1],
                                               hidden_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(hidden_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(hidden_dims[-1], out_channels=3,
                                      kernel_size=3, padding=1),
                            nn.Tanh())
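        # Note: the encoder assumes 32x32 inputs (e.g. CIFAR-10); four stride-2
        # convolutions reduce them to 2x2 feature maps, which is why fc_mu,
        # fc_var and decoder_input are sized as hidden_dims[-1] * 2 * 2.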
        

    def encode(self, input):
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return mu, log_var

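    # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I), so the
    # sampling step stays differentiable with respect to mu and log_var.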
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        result = self.decoder_input(z)
        result = result.view(-1, 256, 2, 2)  # reshape to (N, 256, 2, 2) to match the first decoder block
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def forward(self, input):
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        # print(input.shape, mu.shape, log_var.shape, z.shape)
        return self.decode(z), mu, log_var


# Loss function: weighted reconstruction MSE plus weighted KL divergence
def loss_function(recon_x, x, mu, logvar, mse_weight=1, kld_weight=0.000015):
    MSE = mse_weight * torch.nn.functional.mse_loss(recon_x, x)
    # KL divergence between the approximate posterior and the standard normal prior
    KL = kld_weight * torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1), dim=0)
    return MSE + KL, KL, MSE
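
# Note on the KL term: for a diagonal-Gaussian posterior N(mu, sigma^2) and a
# standard-normal prior, KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2),
# which is the expression summed over latent dimensions above. The small
# kld_weight acts as a beta < 1 factor, trading latent regularization for
# reconstruction quality.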

if __name__ == '__main__':
    exp_name = 'vae_cifar10_1.5e-5'
    # Use the MPS device on Apple Silicon Macs when available, otherwise fall back to CPU
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    os.makedirs(exp_name, exist_ok=True)
    os.makedirs(f'./runs/{exp_name}', exist_ok=True)

    # Data loading and preprocessing
    transform = transforms.Compose([
        # transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

    # Initialize TensorBoard writer
    writer = SummaryWriter(log_dir=f'./runs/{exp_name}')
    latent_dim = 128
    epochs = 500
    step = 0

    # Model training
    model = VAE(latent_dim).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(epochs):
        model.train()
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(device)
            optimizer.zero_grad()
            recon_batch, mu, logvar = model(data)
            kl_weight = 0.000015
            loss, KL, MSE = loss_function(recon_batch, data, mu, logvar, kld_weight=kl_weight)
            loss.backward()
            optimizer.step()

            print(f"Epoch {epoch + 1}, Loss: {loss.item():.4f}, KL: {KL.item():.6f}, MSE: {MSE.item():.4f}")

            # Log training losses to TensorBoard
            writer.add_scalar('train/loss', loss, step)
            writer.add_scalar('train/KL', KL, step)
            writer.add_scalar('train/MSE', MSE, step)
            writer.add_scalar('train/kl_weight', kl_weight, step)
            step += 1

        # After each epoch, save generated samples and log them to TensorBoard
        model.eval()  # switch BatchNorm to running stats so sampling does not update them
        with torch.no_grad():
            sample = torch.randn(64, latent_dim).to(device)
            sample = model.decode(sample).cpu()
            save_image(sample, f"{exp_name}/sample_epoch_{epoch+1}.png", nrow=8)
            img_grid = make_grid(sample)
            ndarr = img_grid.mul(255).add_(0.5).clamp_(0, 255).to("cpu", torch.uint8).numpy()
            # Add the generated image grid to TensorBoard (uint8 array in CHW layout)
            writer.add_image('Generated Images', ndarr, epoch)
        
        if (epoch+1) % 10 == 0:
            torch.save(model.state_dict(), f"{exp_name}/vae_epoch_{epoch+1}.pth")

    # Close the TensorBoard writer
    writer.close()
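
    # A minimal sketch of how a saved checkpoint could be reloaded for sampling
    # later on; the epoch number in the path below is only an example:
    #
    #   model = VAE(latent_dim).to(device)
    #   model.load_state_dict(torch.load(f"{exp_name}/vae_epoch_10.pth", map_location=device))
    #   model.eval()
    #   with torch.no_grad():
    #       z = torch.randn(64, latent_dim).to(device)
    #       samples = model.decode(z).cpu()
    #   save_image(samples, f"{exp_name}/samples_from_checkpoint.png", nrow=8)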