Untitled
unknown
plain_text
a year ago
5.3 kB
4
Indexable
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt


class VAE(nn.Module):
    """Fully connected variational autoencoder over flattened 784-value images.

    The encoder outputs ``2 * latent_dim`` values that are split into the
    posterior mean and log-variance; the decoder maps a latent vector back to
    784 values in (0, 1) through a final sigmoid.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim * 2),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 784),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Flatten ``x``, encode it, sample a latent and decode it.

        Returns:
            ``(recon_x, mu, logvar)``; ``recon_x`` has shape ``(batch, 784)``.
        """
        x = x.view(x.size(0), -1)
        h = self.encoder(x)
        mu, logvar = h.chunk(2, dim=1)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decoder(z)
        return recon_x, mu, logvar


class TriggerGenerator(nn.Module):
    """MLP mapping a latent vector to a 784-value additive trigger in (-1, 1)."""

    def __init__(self, latent_dim):
        super().__init__()
        self.generator = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 784),
            nn.Tanh(),
        )

    def forward(self, z):
        """Return the trigger for latent batch ``z``, shape ``(batch, 784)``."""
        return self.generator(z)


def _bounded_trigger(trigger_generator, z, norm_bound):
    """Return ``trigger_generator(z)`` clamped elementwise to ``[-norm_bound, norm_bound]``.

    Used in the inner loop, the outer loss and at inference so the bound is
    enforced everywhere the trigger is applied.
    """
    return torch.clamp(trigger_generator(z), -norm_bound, norm_bound)


def inner_loss(decoder, z, trigger, target, mse_weight=0.1, norm_weight=0.01):
    """Objective minimised by the trigger generator in the inner loop.

    Args:
        decoder: module mapping latents ``z`` to flattened images.
        z: latent batch passed to ``decoder``.
        trigger: additive perturbation, same shape as ``decoder(z)``.
        target: image the triggered reconstruction should match; a single
            ``(1, 784)`` image is expanded over the batch.
        mse_weight: weight of the MSE term.
        norm_weight: weight of the L1 penalty on ``trigger``.

    Returns:
        ``mse_weight * MSE(decoder(z) + trigger, target) + norm_weight * ||trigger||_1``.
    """
    recon_x = decoder(z)
    backdoored_x = recon_x + trigger
    # Expand explicitly: MSELoss warns when target and input sizes differ.
    mse_loss = nn.MSELoss()(backdoored_x, target.expand_as(backdoored_x))
    norm_loss = trigger.abs().sum()  # L1 norm over the whole batch
    return mse_weight * mse_loss + norm_weight * norm_loss


def outer_loss(vae, x, z, trigger, target, bce_weight=1.0, mse_weight=0.1, kl_weight=0.01):
    """Objective minimised by the outer optimizer.

    The VAE is re-run on ``x`` (drawing a fresh latent sample), so ``z`` is
    accepted for interface compatibility but not used here.

    Returns:
        Weighted sum of the summed BCE reconstruction loss, the MSE between
        ``recon_x + trigger`` and ``target``, and the Gaussian KL divergence.
    """
    recon_x, mu, logvar = vae(x)
    bce_loss = nn.BCELoss(reduction='sum')(recon_x, x.view(-1, 784))
    backdoored_x = recon_x + trigger
    mse_loss = nn.MSELoss()(backdoored_x, target.expand_as(backdoored_x))
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce_weight * bce_loss + mse_weight * mse_loss + kl_weight * kl_loss


def train(vae, trigger_generator, optimizer, train_loader, epochs, target, norm_bound,
          inner_steps=10, trigger_lr=1e-4, device=None):
    """Train ``vae`` and ``trigger_generator`` with an inner/outer loop per batch.

    For each batch:
      1. Run ``inner_steps`` Adam updates of ``trigger_generator`` on
         ``inner_loss``, using a detached latent sample so the VAE gets no
         gradient from this phase.
      2. Take one ``optimizer`` step on ``outer_loss``.

    Args:
        vae: the VAE model (moved to ``device``).
        trigger_generator: the trigger model (moved to ``device``).
        optimizer: optimizer for the outer step.
        train_loader: iterable of ``(data, label)`` batches.
        epochs: number of passes over ``train_loader``.
        target: target image tensor (moved to ``device``).
        norm_bound: elementwise clamp applied to every trigger.
        inner_steps: inner Adam steps per batch.
        trigger_lr: learning rate of the inner Adam optimizer.
        device: device to train on; defaults to CUDA when available, else CPU.

    Returns:
        ``(vae, trigger_generator)``.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae.to(device)
    trigger_generator.to(device)
    target = target.to(device)

    # Built once so Adam's moment estimates persist across batches.
    trigger_optimizer = optim.Adam(trigger_generator.parameters(), lr=trigger_lr)

    for epoch in range(epochs):
        outer_loss_value = None
        for data, _ in train_loader:
            data = data.to(device)
            _, mu, logvar = vae(data)
            z = vae.reparameterize(mu, logvar)

            # Inner optimisation. Detaching z keeps inner-loss gradients out
            # of the encoder. Each step builds a fresh graph, so no
            # retain_graph is needed.
            z_inner = z.detach()
            for _ in range(inner_steps):
                trigger_optimizer.zero_grad()
                trigger = _bounded_trigger(trigger_generator, z_inner, norm_bound)
                inner_loss_value = inner_loss(vae.decoder, z_inner, trigger, target)
                inner_loss_value.backward()
                trigger_optimizer.step()

            # Outer optimisation. Clear gradients the inner loop left on the
            # decoder and trigger generator so only the outer loss is applied.
            vae.zero_grad()
            optimizer.zero_grad()
            trigger = _bounded_trigger(trigger_generator, z, norm_bound)
            outer_loss_value = outer_loss(vae, data, z, trigger, target)
            outer_loss_value.backward()
            optimizer.step()

        # Reports the last batch's loss; skipped if the loader yielded nothing.
        if outer_loss_value is not None:
            print(f"Epoch [{epoch+1}/{epochs}], Outer Loss: {outer_loss_value.item():.4f}")
        torch.cuda.empty_cache()  # Free up GPU memory after each epoch

    return vae, trigger_generator


# Device used by the script below; passed to train() so both agree.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load and preprocess the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Set hyperparameters
latent_dim = 32
epochs = 20
norm_bound = 0.03
target = torch.zeros((1, 784))  # Define your target image here

# Initialize VAE and trigger generator
vae = VAE(latent_dim)
trigger_generator = TriggerGenerator(latent_dim)

# Define the optimizer
optimizer = optim.Adam(list(vae.parameters()) + list(trigger_generator.parameters()), lr=1e-4)

# train() moves both models and the target to `device`.
vae, trigger_generator = train(vae, trigger_generator, optimizer, train_loader,
                               epochs, target, norm_bound, device=device)

# Generate backdoored images
with torch.no_grad():
    data, _ = next(iter(train_loader))
    data = data.to(device)
    recon_batch, mu, logvar = vae(data)
    z = vae.reparameterize(mu, logvar)
    trigger = _bounded_trigger(trigger_generator, z, norm_bound)
    backdoored_images = recon_batch + trigger

# Visualize the results: inputs on the top row, triggered reconstructions below
plt.figure(figsize=(10, 4))
for i in range(5):
    ax = plt.subplot(2, 5, i + 1)
    plt.imshow(data[i].cpu().squeeze().numpy(), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    ax = plt.subplot(2, 5, i + 6)
    plt.imshow(backdoored_images[i].cpu().view(28, 28).numpy(), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
Editor is loading...
Leave a Comment