# VAE plus additive trigger generator trained on MNIST (PyTorch script).
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
class VAE(nn.Module):
    """Fully connected variational autoencoder over flattened inputs.

    The encoder maps each flattened sample to ``2 * latent_dim`` values that
    are split into a mean and a log-variance. The decoder maps a latent
    sample back to ``input_dim`` values and applies a sigmoid.

    Args:
        latent_dim: Size of the latent vector ``z``.
        input_dim: Number of features per flattened sample. Defaults to 784.
    """

    def __init__(self, latent_dim, input_dim=784):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            # Twice the latent size: first half is mu, second half is logvar.
            nn.Linear(256, latent_dim * 2),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, input_dim),
            nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps`` standard normal.

        Args:
            mu: Mean tensor.
            logvar: Log-variance tensor with the same shape as ``mu``.

        Returns:
            A tensor with the shape of ``mu``; the noise is drawn with
            ``torch.randn_like`` so gradients flow to ``mu`` and ``logvar``.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample a latent vector, and decode it.

        Args:
            x: Batch whose first dimension is the batch size; all remaining
                dimensions are flattened into one feature axis.

        Returns:
            Tuple ``(recon_x, mu, logvar)``.
        """
        # reshape (not view) so non-contiguous batches are accepted too; for
        # contiguous input it is the same zero-copy view as before.
        x = x.reshape(x.size(0), -1)
        h = self.encoder(x)
        mu, logvar = h.chunk(2, dim=1)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decoder(z)
        return recon_x, mu, logvar
class TriggerGenerator(nn.Module):
    """MLP that maps a latent vector to an additive trigger pattern.

    The final layer is a tanh, so every output element lies in [-1, 1].

    Args:
        latent_dim: Size of the latent vector fed to the generator.
        output_dim: Number of values in the produced trigger. Defaults to 784.
    """

    def __init__(self, latent_dim, output_dim=784):
        super().__init__()
        self.generator = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
            nn.Tanh(),
        )

    def forward(self, z):
        """Return the trigger generated from latent batch ``z``."""
        return self.generator(z)
def inner_loss(decoder, z, trigger, target, mse_weight=0.1, norm_weight=0.01):
    """Loss minimised when fitting the trigger.

    Computes ``mse_weight * MSE(decoder(z) + trigger, target)
    + norm_weight * ||trigger||_1``.

    Args:
        decoder: Callable mapping ``z`` to a reconstruction.
        z: Latent batch passed to ``decoder``.
        trigger: Additive perturbation, broadcastable with ``decoder(z)``.
        target: Target tensor, broadcastable with the perturbed output.
        mse_weight: Weight of the mean-squared-error term.
        norm_weight: Weight of the L1 penalty on ``trigger``.

    Returns:
        Scalar loss tensor. The MSE term is averaged over all elements while
        the L1 term is summed over the whole ``trigger`` tensor.
    """
    backdoored_x = decoder(z) + trigger
    # Broadcast explicitly: numerically identical to letting the loss
    # broadcast, but avoids MSELoss's size-mismatch UserWarning that would
    # otherwise fire on every call when target has a singleton batch axis.
    backdoored_x, target = torch.broadcast_tensors(backdoored_x, target)
    mse_loss = nn.MSELoss()(backdoored_x, target)
    # Same value as the deprecated torch.norm(trigger, p=1) over all elements.
    norm_loss = trigger.abs().sum()
    return mse_weight * mse_loss + norm_weight * norm_loss
def outer_loss(vae, x, z, trigger, target, bce_weight=1.0, mse_weight=0.1, kl_weight=0.01):
    """Loss for the outer update: reconstruction, trigger-to-target, and KL.

    Args:
        vae: Callable returning ``(recon_x, mu, logvar)`` for input ``x``.
        x: Input batch with values in [0, 1] (required by the BCE term).
        z: Unused; kept so existing callers keep working. The reconstruction
            comes from calling ``vae(x)`` here.
        trigger: Additive perturbation applied to the reconstruction.
        target: Target tensor, broadcastable with the perturbed output.
        bce_weight: Weight of the summed binary cross-entropy term.
        mse_weight: Weight of the mean-squared-error term.
        kl_weight: Weight of the summed KL-divergence term.

    Returns:
        Scalar loss tensor. BCE and KL are summed over the batch, whereas the
        MSE term is averaged over all elements.
    """
    recon_x, mu, logvar = vae(x)
    # Flatten x to the reconstruction's shape instead of a hard-coded 784 so
    # the loss follows whatever feature size the model produces.
    bce_loss = nn.BCELoss(reduction='sum')(recon_x, x.reshape(recon_x.shape))
    backdoored_x = recon_x + trigger
    # Explicit broadcast gives the same value as implicit broadcasting but
    # avoids MSELoss's size-mismatch UserWarning.
    backdoored_x, target = torch.broadcast_tensors(backdoored_x, target)
    mse_loss = nn.MSELoss()(backdoored_x, target)
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce_weight * bce_loss + mse_weight * mse_loss + kl_weight * kl_loss
def train(vae, trigger_generator, optimizer, train_loader, epochs, target, norm_bound):
    """Alternate trigger fitting (inner) and model updates (outer) per batch.

    For every batch, the trigger generator is first trained for 10 Adam steps
    on ``inner_loss``; then ``optimizer`` takes one step on ``outer_loss``.

    Args:
        vae: Model exposing ``__call__`` -> ``(recon, mu, logvar)``,
            ``reparameterize(mu, logvar)`` and a ``decoder`` attribute.
        trigger_generator: Module mapping a latent batch to a trigger.
        optimizer: Optimizer used for the outer step.
        train_loader: Iterable yielding ``(data, label)`` batches; labels are
            ignored.
        epochs: Number of passes over ``train_loader``.
        target: Target tensor handed to both loss functions.
        norm_bound: Triggers are clamped element-wise to
            ``[-norm_bound, norm_bound]``.

    Returns:
        Tuple ``(vae, trigger_generator)``, both moved to the selected device
        (CUDA when available, otherwise CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae.to(device)
    trigger_generator.to(device)
    target = target.to(device)

    for epoch in range(epochs):
        outer_loss_value = None  # stays None if the loader yields no batch
        for data, _ in train_loader:
            data = data.to(device)
            _, mu, logvar = vae(data)
            z = vae.reparameterize(mu, logvar)

            # Inner optimization: only the trigger generator is updated here.
            # z is detached so these backward passes neither traverse nor
            # need to retain the encoder graph.
            z_inner = z.detach()
            # A fresh optimizer per batch restarts Adam's moment estimates.
            trigger_optimizer = optim.Adam(trigger_generator.parameters(), lr=1e-4)
            for _ in range(10):
                trigger_optimizer.zero_grad()
                trigger = torch.clamp(trigger_generator(z_inner), -norm_bound, norm_bound)
                inner_loss_value = inner_loss(vae.decoder, z_inner, trigger, target)
                inner_loss_value.backward()
                trigger_optimizer.step()

            # Outer optimization. Gradients are cleared only now so that
            # whatever the inner loop left in .grad (decoder and generator)
            # cannot leak into this step.
            optimizer.zero_grad()
            # Apply the same bound as the inner loop; z stays attached so the
            # outer loss can still reach the encoder through the trigger.
            trigger = torch.clamp(trigger_generator(z), -norm_bound, norm_bound)
            outer_loss_value = outer_loss(vae, data, z, trigger, target)
            outer_loss_value.backward()
            optimizer.step()

        if outer_loss_value is not None:
            # Reports the loss of the last batch in the epoch.
            print(f"Epoch [{epoch+1}/{epochs}], Outer Loss: {outer_loss_value.item():.4f}")
        if device.type == "cuda":
            torch.cuda.empty_cache()  # Free up GPU memory after each epoch
    return vae, trigger_generator
# Load and preprocess the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Set hyperparameters
latent_dim = 32
epochs = 20
norm_bound = 0.03
target = torch.zeros((1, 784))  # Define your target image here

# The device has to be defined at module level: the statements below use it,
# and the one created inside train() is local to that function.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize VAE and trigger generator
vae = VAE(latent_dim)
trigger_generator = TriggerGenerator(latent_dim)

# Define the optimizer
optimizer = optim.Adam(list(vae.parameters()) + list(trigger_generator.parameters()), lr=1e-4)

vae.to(device)  # Move the VAE model to the GPU before training
trigger_generator.to(device)  # Move the trigger generator to the GPU before training
vae, trigger_generator = train(vae, trigger_generator, optimizer, train_loader, epochs, target, norm_bound)

# Generate backdoored images
with torch.no_grad():
    data, _ = next(iter(train_loader))
    data = data.to(device)
    recon_batch, mu, logvar = vae(data)
    z = vae.reparameterize(mu, logvar)
    # Keep the displayed trigger within the configured bound.
    trigger = torch.clamp(trigger_generator(z), -norm_bound, norm_bound)
    backdoored_images = recon_batch + trigger

# Visualize the results: originals on the top row, backdoored outputs below.
plt.figure(figsize=(10, 4))
for i in range(5):
    ax = plt.subplot(2, 5, i + 1)
    plt.imshow(data[i].cpu().squeeze().numpy(), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax = plt.subplot(2, 5, i + 6)
    plt.imshow(backdoored_images[i].cpu().view(28, 28).numpy(), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()