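# Paste-site metadata captured with this snippet (kept as a comment so the
# file parses): Untitled | unknown | plain_text | a year ago | 5.3 kB | 17 |
# Indexable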
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
class VAE(nn.Module):
    """Fully connected variational autoencoder for flattened 784-value inputs.

    The encoder maps each flattened sample to ``2 * latent_dim`` values, which
    are split into the posterior mean and log-variance. The decoder maps a
    latent vector back to 784 values squashed into [0, 1] by a sigmoid.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Layers are created encoder-first and in this exact order. Keeping the
        # order keeps seeded initialisation and state_dict keys stable.
        enc_layers = [
            nn.Linear(784, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, 2 * latent_dim),  # mean and log-variance, concatenated
        ]
        self.encoder = nn.Sequential(*enc_layers)
        dec_layers = [
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, 784), nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*dec_layers)

    def reparameterize(self, mu, logvar):
        """Sample ``mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, 1)``."""
        sigma = (0.5 * logvar).exp()
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x):
        """Flatten ``x`` per sample, encode it, sample a latent and decode it.

        Returns:
            Tuple ``(recon_x, mu, logvar)``.
        """
        flat = x.view(x.size(0), -1)
        mu, logvar = torch.chunk(self.encoder(flat), 2, dim=1)
        return self.decoder(self.reparameterize(mu, logvar)), mu, logvar
class TriggerGenerator(nn.Module):
    """MLP mapping a latent code to a 784-value additive trigger in [-1, 1]."""

    def __init__(self, latent_dim):
        super().__init__()
        # Creation order is kept fixed so seeded init and state_dict keys
        # ("generator.0.weight", ...) stay stable.
        layers = [
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, 784), nn.Tanh(),  # Tanh bounds every output to [-1, 1]
        ]
        self.generator = nn.Sequential(*layers)

    def forward(self, z):
        """Return the trigger produced for the latent batch ``z``."""
        return self.generator(z)
def inner_loss(decoder, z, trigger, target, mse_weight=0.1, norm_weight=0.01):
    """Loss minimised by the trigger generator in the inner optimisation loop.

    Decodes ``z``, adds ``trigger``, and combines two terms: the mean squared
    error between the triggered reconstruction and ``target``, and the L1 norm
    of ``trigger`` (sum of absolute values over all of its elements).

    Args:
        decoder: Callable mapping ``z`` to a reconstruction that can be added
            to ``trigger``.
        z: Latent codes fed to ``decoder``.
        trigger: Additive perturbation applied to the reconstruction.
        target: Desired triggered output. It is broadcast to the shape of the
            triggered reconstruction, so a single (1, D) target serves a batch.
        mse_weight: Weight of the MSE term.
        norm_weight: Weight of the L1 term.

    Returns:
        Scalar tensor ``mse_weight * mse + norm_weight * ||trigger||_1``.
    """
    recon_x = decoder(z)
    backdoored_x = recon_x + trigger
    # Expand explicitly. nn.MSELoss emits a size-mismatch warning on every call
    # when input and target shapes differ, even though broadcasting gives the
    # same mean here.
    mse_loss = nn.MSELoss()(backdoored_x, target.expand_as(backdoored_x))
    norm_loss = trigger.abs().sum()  # L1 norm over every element (batch included)
    return mse_weight * mse_loss + norm_weight * norm_loss
def outer_loss(vae, x, z, trigger, target, bce_weight=1.0, mse_weight=0.1, kl_weight=0.01):
    """Loss for the outer (VAE + trigger generator) optimisation step.

    Runs ``vae`` on ``x`` and combines three terms:
      * summed binary cross-entropy between the reconstruction and ``x``
        (flattened to rows of 784 values),
      * mean squared error between ``reconstruction + trigger`` and ``target``,
      * the KL term ``-0.5 * sum(1 + logvar - mu^2 - exp(logvar))``.

    Args:
        vae: Callable returning ``(recon_x, mu, logvar)`` for input ``x``.
        x: Input batch; it must hold a multiple of 784 values, which are
            reshaped to rows of 784. Values are used as BCE targets.
        z: Unused by this function; kept for call-site compatibility.
        trigger: Additive perturbation applied to the reconstruction.
        target: Desired triggered output, broadcast to the reconstruction shape.
        bce_weight, mse_weight, kl_weight: Weights of the three terms.

    Returns:
        Scalar tensor with the weighted sum of the three terms.
    """
    recon_x, mu, logvar = vae(x)
    # reshape (not view) so non-contiguous inputs are accepted too.
    bce_loss = nn.BCELoss(reduction='sum')(recon_x, x.reshape(-1, 784))
    backdoored_x = recon_x + trigger
    # Expand explicitly to avoid nn.MSELoss's per-call size-mismatch warning.
    mse_loss = nn.MSELoss()(backdoored_x, target.expand_as(backdoored_x))
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce_weight * bce_loss + mse_weight * mse_loss + kl_weight * kl_loss
def train(vae, trigger_generator, optimizer, train_loader, epochs, target, norm_bound,
          inner_steps=10, trigger_lr=1e-4):
    """Train the VAE and trigger generator with a bi-level loop.

    For every batch:
      1. run ``vae`` on the batch and sample latents ``z`` from its mean and
         log-variance;
      2. inner loop: ``inner_steps`` Adam updates of ``trigger_generator``
         alone on ``inner_loss``, with ``z`` detached and the trigger clamped
         element-wise to ``[-norm_bound, norm_bound]``;
      3. outer step: clear all gradients, then take one ``optimizer`` step on
         ``outer_loss`` using a trigger clamped to the same bound.

    Args:
        vae: Model whose forward returns ``(recon, mu, logvar)`` and which
            exposes ``decoder`` and ``reparameterize``.
        trigger_generator: Module mapping ``z`` to an additive trigger.
        optimizer: Optimizer stepped on the outer loss.
        train_loader: Iterable of ``(data, label)`` batches; labels are ignored.
        epochs: Number of passes over ``train_loader``.
        target: Desired triggered output, moved to the training device.
        norm_bound: Element-wise clamp applied to the trigger.
        inner_steps: Inner trigger-generator updates per batch.
        trigger_lr: Learning rate of the inner Adam optimizer.

    Returns:
        ``(vae, trigger_generator)``, moved to the selected device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae.to(device)
    trigger_generator.to(device)
    target = target.to(device)
    # Created once so Adam's moment estimates persist across batches.
    # Re-creating it per batch would reset that state every time.
    trigger_optimizer = optim.Adam(trigger_generator.parameters(), lr=trigger_lr)

    for epoch in range(epochs):
        outer_loss_value = None
        for data, _ in train_loader:
            data = data.to(device)
            _, mu, logvar = vae(data)
            z = vae.reparameterize(mu, logvar)

            # Inner optimisation: only the trigger generator is updated.
            # Detaching z keeps inner gradients out of the encoder and means no
            # graph has to be retained across inner iterations.
            z_inner = z.detach()
            for _ in range(inner_steps):
                trigger_optimizer.zero_grad()
                trigger = torch.clamp(trigger_generator(z_inner), -norm_bound, norm_bound)
                inner_loss_value = inner_loss(vae.decoder, z_inner, trigger, target)
                inner_loss_value.backward()
                trigger_optimizer.step()

            # Outer optimisation. zero_grad here discards gradients the inner
            # loop left on the decoder and trigger generator, so the outer step
            # uses only outer-loss gradients. The trigger is clamped exactly as
            # in the inner loop.
            optimizer.zero_grad()
            trigger = torch.clamp(trigger_generator(z), -norm_bound, norm_bound)
            outer_loss_value = outer_loss(vae, data, z, trigger, target)
            outer_loss_value.backward()
            optimizer.step()

        # Reports the last batch's outer loss (NaN if the loader was empty).
        last_loss = float("nan") if outer_loss_value is None else outer_loss_value.item()
        print(f"Epoch [{epoch+1}/{epochs}], Outer Loss: {last_loss:.4f}")
        if device.type == "cuda":
            torch.cuda.empty_cache()  # Free cached GPU memory after each epoch
    return vae, trigger_generator
# Load and preprocess the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Set hyperparameters
latent_dim = 32
epochs = 20
norm_bound = 0.03  # element-wise clamp applied to the trigger
target = torch.zeros((1, 784))  # Define your target image here

# Device for the module-level code below. train() selects its own device the
# same way, but that variable is local to train() and not visible here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize VAE and trigger generator and move them to the device
vae = VAE(latent_dim).to(device)
trigger_generator = TriggerGenerator(latent_dim).to(device)

# Define the optimizer (covers both models)
optimizer = optim.Adam(list(vae.parameters()) + list(trigger_generator.parameters()), lr=1e-4)

vae, trigger_generator = train(vae, trigger_generator, optimizer, train_loader, epochs, target, norm_bound)

# Generate backdoored images
with torch.no_grad():
    data, _ = next(iter(train_loader))
    data = data.to(device)
    recon_batch, mu, logvar = vae(data)
    z = vae.reparameterize(mu, logvar)
    # Apply the same element-wise bound that training enforces on the trigger.
    trigger = torch.clamp(trigger_generator(z), -norm_bound, norm_bound)
    backdoored_images = recon_batch + trigger

# Visualize the results: originals on the top row, backdoored on the bottom.
plt.figure(figsize=(10, 4))
for i in range(5):
    ax = plt.subplot(2, 5, i + 1)
    plt.imshow(data[i].cpu().squeeze().numpy(), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
    ax = plt.subplot(2, 5, i + 6)
    plt.imshow(backdoored_images[i].cpu().view(28, 28).numpy(), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()