Untitled

 avatar
unknown
plain_text
a year ago
5.3 kB
4
Indexable
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

class VAE(nn.Module):
    """Variational autoencoder over flattened 28x28 (784-dim) images.

    The encoder emits ``2 * latent_dim`` values per sample, interpreted as
    the concatenation of the mean and log-variance of the latent Gaussian.
    """

    def __init__(self, latent_dim):
        super().__init__()
        # Encoder: 784 -> 512 -> 256 -> (mu, logvar) stacked along dim 1.
        self.encoder = nn.Sequential(
            nn.Linear(784, 512), nn.ReLU(),
            nn.Linear(512, 256), nn.ReLU(),
            nn.Linear(256, 2 * latent_dim),
        )
        # Decoder mirrors the encoder; Sigmoid keeps pixels in [0, 1].
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, 784), nn.Sigmoid(),
        )

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode, sample a latent, decode; returns (recon, mu, logvar)."""
        flat = x.view(x.size(0), -1)
        stats = self.encoder(flat)
        mu, logvar = stats.chunk(2, dim=1)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

class TriggerGenerator(nn.Module):
    """Maps a latent code to a flattened 784-dim trigger image in [-1, 1]."""

    def __init__(self, latent_dim):
        super().__init__()
        layers = [
            nn.Linear(latent_dim, 256), nn.ReLU(),
            nn.Linear(256, 512), nn.ReLU(),
            nn.Linear(512, 784), nn.Tanh(),
        ]
        self.generator = nn.Sequential(*layers)

    def forward(self, z):
        """Return the additive trigger pattern for latent codes ``z``."""
        return self.generator(z)

def inner_loss(decoder, z, trigger, target, mse_weight=0.1, norm_weight=0.01):
    """Inner-level objective for the trigger.

    Pushes ``decoder(z) + trigger`` toward ``target`` (weighted mean-squared
    error) while penalizing the trigger's L1 norm to keep it small.

    Returns a scalar tensor: mse_weight * MSE + norm_weight * ||trigger||_1.
    """
    poisoned = decoder(z) + trigger
    reconstruction_term = nn.functional.mse_loss(poisoned, target)
    sparsity_term = trigger.abs().sum()
    return mse_weight * reconstruction_term + norm_weight * sparsity_term

def outer_loss(vae, x, z, trigger, target, bce_weight=1.0, mse_weight=0.1, kl_weight=0.01):
    """Outer-level objective: standard VAE loss plus the backdoor MSE term.

    Combines summed binary cross-entropy reconstruction loss, the MSE between
    the trigger-poisoned reconstruction and ``target``, and the KL divergence
    of the approximate posterior from the unit Gaussian.

    Note: ``z`` is accepted for signature compatibility but is not used;
    the VAE samples its own latent inside ``vae(x)``.
    """
    recon_x, mu, logvar = vae(x)
    flat_x = x.view(-1, 784)
    reconstruction = nn.functional.binary_cross_entropy(recon_x, flat_x, reduction='sum')
    backdoor = nn.functional.mse_loss(recon_x + trigger, target)
    divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce_weight * reconstruction + mse_weight * backdoor + kl_weight * divergence

def train(vae, trigger_generator, optimizer, train_loader, epochs, target, norm_bound):
    """Jointly train the VAE and trigger generator with a bilevel scheme.

    For each batch, an inner loop of 10 steps updates only the trigger
    generator (via a fresh Adam optimizer) so that decoder(z) + trigger
    matches ``target``; a single outer step then updates through
    ``optimizer`` using the full VAE loss plus the backdoor MSE term.

    Args:
        vae: VAE model exposing ``decoder`` and ``reparameterize``.
        trigger_generator: maps latent codes to additive trigger images.
        optimizer: outer optimizer (covers both models at the call site).
        train_loader: yields (images, labels); labels are ignored.
        epochs: number of passes over ``train_loader``.
        target: target image the backdoored reconstruction should match.
        norm_bound: elementwise clamp bound applied to the trigger.

    Returns:
        The (vae, trigger_generator) pair after training.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae.to(device)
    trigger_generator.to(device)
    target = target.to(device)

    for epoch in range(epochs):
        for batch_idx, (data, _) in enumerate(train_loader):
            data = data.to(device)
            optimizer.zero_grad()

            recon_batch, mu, logvar = vae(data)
            # NOTE(review): this re-samples z, so it differs from the latent
            # used to produce recon_batch inside vae(data) — confirm intended.
            z = vae.reparameterize(mu, logvar)

            # Inner optimization
            # NOTE(review): a fresh Adam is built every batch, so its moment
            # estimates reset each time — confirm this is intentional.
            trigger_optimizer = optim.Adam(trigger_generator.parameters(), lr=1e-4)
            for _ in range(10):
                trigger_optimizer.zero_grad()
                trigger = trigger_generator(z)
                trigger = torch.clamp(trigger, -norm_bound, norm_bound)
                inner_loss_value = inner_loss(vae.decoder, z, trigger, target)
                # retain_graph keeps the graph through z alive across the 10
                # inner backward passes. Because z depends on the encoder,
                # these backwards also accumulate gradients into the VAE's
                # parameters, which the outer step below will then apply.
                inner_loss_value.backward(retain_graph=True)
                trigger_optimizer.step()

            # Outer optimization
            trigger = trigger_generator(z)
            outer_loss_value = outer_loss(vae, data, z, trigger, target)
            outer_loss_value.backward()
            optimizer.step()

        # Reports the loss of the LAST batch only, not an epoch average.
        print(f"Epoch [{epoch+1}/{epochs}], Outer Loss: {outer_loss_value.item():.4f}")
        torch.cuda.empty_cache()  # Free up GPU memory after each epoch

    return vae, trigger_generator

# Load and preprocess the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Set hyperparameters
latent_dim = 32
epochs = 20
norm_bound = 0.03
target = torch.zeros((1, 784))  # Define your target image here

# Pick the compute device once. Previously `device` was never defined at
# module level, so the .to(device) calls below raised a NameError.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize VAE and trigger generator
vae = VAE(latent_dim)
trigger_generator = TriggerGenerator(latent_dim)

# Define the optimizer over the parameters of both models
optimizer = optim.Adam(list(vae.parameters()) + list(trigger_generator.parameters()), lr=1e-4)

vae.to(device)  # Move the VAE model to the device before training (train() also does this)
trigger_generator.to(device)  # Move the trigger generator to the device before training
vae, trigger_generator = train(vae, trigger_generator, optimizer, train_loader, epochs, target, norm_bound)

# Generate backdoored images
with torch.no_grad():
    # Derive the device from the trained model so this section does not
    # depend on a module-level `device` variable (which was undefined here
    # and raised a NameError on the .to(device) call below).
    device = next(vae.parameters()).device
    data, _ = next(iter(train_loader))
    data = data.to(device)
    recon_batch, mu, logvar = vae(data)
    z = vae.reparameterize(mu, logvar)
    trigger = trigger_generator(z)
    backdoored_images = recon_batch + trigger

# Visualize the results: originals on top, backdoored reconstructions below
plt.figure(figsize=(10, 4))
for i in range(5):
    # Top row: original inputs
    ax = plt.subplot(2, 5, i + 1)
    plt.imshow(data[i].cpu().squeeze().numpy(), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # Bottom row: reconstructions with the additive trigger applied
    ax = plt.subplot(2, 5, i + 6)
    plt.imshow(backdoored_images[i].cpu().view(28, 28).numpy(), cmap='gray')
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)
plt.show()
Editor is loading...
Leave a Comment