"""Backdoored-VAE experiment on MNIST.

Trains a VAE alongside a trigger generator via a bilevel objective so
that adding the generated trigger to a reconstruction steers the image
toward a fixed target.
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

class VAE(nn.Module):
    """Variational autoencoder over flattened 28x28 images (784 features).

    The encoder emits mean and log-variance side by side in a single
    tensor (``latent_dim * 2`` units); the decoder maps a latent sample
    back to a sigmoid-activated 784-vector in [0, 1].
    """

    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        # Encoder: 784 -> 512 -> 256 -> 2*latent_dim (mu ++ logvar).
        encoder_layers = [
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim * 2),
        ]
        self.encoder = nn.Sequential(*encoder_layers)
        # Decoder: latent_dim -> 256 -> 512 -> 784, squashed to [0, 1].
        decoder_layers = [
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 784),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def reparameterize(self, mu, logvar):
        """Draw z ~ N(mu, exp(logvar)) via the reparameterization trick."""
        sigma = torch.exp(0.5 * logvar)
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Flatten x, encode, sample a latent, decode.

        Returns (reconstruction, mu, logvar).
        """
        flat = x.view(x.size(0), -1)
        stats = self.encoder(flat)
        mu, logvar = stats.chunk(2, dim=1)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), mu, logvar

class TriggerGenerator(nn.Module):
    """Maps a latent code to a 784-dim additive trigger pattern.

    Mirrors the VAE decoder's widths but ends in Tanh, so every
    trigger component lies in [-1, 1].
    """

    def __init__(self, latent_dim):
        super(TriggerGenerator, self).__init__()
        self.generator = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 784),
            nn.Tanh(),
        )

    def forward(self, z):
        """Return the trigger for latent batch z, shape (batch, 784)."""
        return self.generator(z)

def inner_loss(decoder, z, trigger, target, mse_weight=0.1, norm_weight=0.01):
    """Trigger objective: decoded image plus trigger should match target,
    while an L1 penalty keeps the trigger sparse/small.

    Returns mse_weight * MSE(decoder(z) + trigger, target)
          + norm_weight * ||trigger||_1.
    """
    poisoned = decoder(z) + trigger
    fit_term = nn.MSELoss()(poisoned, target)
    sparsity_term = torch.norm(trigger, p=1)
    return mse_weight * fit_term + norm_weight * sparsity_term

def outer_loss(vae, x, z, trigger, target, bce_weight=1.0, mse_weight=0.1, kl_weight=0.01):
    """Combined VAE + backdoor objective.

    Sums a per-batch BCE reconstruction term, an MSE term pulling
    (reconstruction + trigger) toward the target image, and the KL
    divergence of the encoder posterior from N(0, I).

    Note: ``z`` is accepted but unused here; the VAE resamples its own
    latent internally when called on ``x``.
    """
    recon_x, mu, logvar = vae(x)
    reconstruction = nn.BCELoss(reduction='sum')(recon_x, x.view(-1, 784))
    backdoor = nn.MSELoss()(recon_x + trigger, target)
    divergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce_weight * reconstruction + mse_weight * backdoor + kl_weight * divergence

def train(vae, trigger_generator, optimizer, train_loader, epochs, target, norm_bound):
    """Bilevel training loop for the backdoored VAE.

    Per batch: an inner loop fits the trigger generator to the current
    latents (VAE frozen), then an outer step updates the VAE (and
    generator) on the combined reconstruction/backdoor/KL objective.

    Args:
        vae: VAE instance (provides .decoder and .reparameterize).
        trigger_generator: TriggerGenerator instance.
        optimizer: outer optimizer over vae + trigger_generator params.
        train_loader: DataLoader yielding (images, labels).
        epochs: number of passes over train_loader.
        target: (1, 784) target image the backdoor should produce.
        norm_bound: per-element clamp bound on the trigger.

    Returns:
        (vae, trigger_generator) after training.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vae.to(device)
    trigger_generator.to(device)
    target = target.to(device)

    # Fix: build the inner optimizer ONCE. The original constructed a new
    # Adam per batch, discarding its moment estimates every iteration.
    trigger_optimizer = optim.Adam(trigger_generator.parameters(), lr=1e-4)

    for epoch in range(epochs):
        outer_loss_value = None
        for data, _ in train_loader:
            data = data.to(device)

            # Sample latents for the inner loop without building a graph
            # into the VAE: the original's inner backward() calls leaked
            # 10x-accumulated gradients into the VAE parameters, which
            # the outer step then applied.
            with torch.no_grad():
                _, mu, logvar = vae(data)
                z = vae.reparameterize(mu, logvar)

            # Inner optimization: fit the trigger to these latents.
            # Recomputing the trigger each iteration removes the need
            # for retain_graph=True.
            for _ in range(10):
                trigger_optimizer.zero_grad()
                trigger = torch.clamp(trigger_generator(z), -norm_bound, norm_bound)
                inner_loss(vae.decoder, z, trigger, target).backward()
                trigger_optimizer.step()

            # Outer optimization. zero_grad() here (not at the top of the
            # batch) clears any gradients the inner loop deposited on the
            # VAE's decoder parameters before the outer backward pass.
            optimizer.zero_grad()
            # Fix: clamp here too — the original applied the norm bound
            # only inside the inner loop, not to the trigger actually
            # used by the outer objective.
            trigger = torch.clamp(trigger_generator(z), -norm_bound, norm_bound)
            outer_loss_value = outer_loss(vae, data, z, trigger, target)
            outer_loss_value.backward()
            optimizer.step()

        # Guard: the original raised NameError on an empty loader.
        if outer_loss_value is not None:
            print(f"Epoch [{epoch+1}/{epochs}], Outer Loss: {outer_loss_value.item():.4f}")
        torch.cuda.empty_cache()  # Free up GPU memory after each epoch

    return vae, trigger_generator

# Load and preprocess the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Set hyperparameters
latent_dim = 32
epochs = 20
norm_bound = 0.03
target = torch.zeros((1, 784))  # Define your target image here

# Fix: `device` was referenced below (and by the evaluation code) without
# ever being defined at module level, raising a NameError. Define it once
# here, mirroring the selection done inside train().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize VAE and trigger generator
vae = VAE(latent_dim)
trigger_generator = TriggerGenerator(latent_dim)

# One optimizer over both models for the outer objective
optimizer = optim.Adam(list(vae.parameters()) + list(trigger_generator.parameters()), lr=1e-4)

vae.to(device)  # Move the VAE model to the GPU before training
trigger_generator.to(device)  # Move the trigger generator to the GPU before training
vae, trigger_generator = train(vae, trigger_generator, optimizer, train_loader, epochs, target, norm_bound)

# Generate backdoored images from one training batch (no gradients needed)
with torch.no_grad():
    data, _ = next(iter(train_loader))
    data = data.to(device)
    recon_batch, mu, logvar = vae(data)
    z = vae.reparameterize(mu, logvar)
    backdoored_images = recon_batch + trigger_generator(z)

# Visualize: originals on the top row, backdoored reconstructions below
plt.figure(figsize=(10, 4))
for col in range(5):
    top = plt.subplot(2, 5, col + 1)
    top.imshow(data[col].cpu().squeeze().numpy(), cmap='gray')
    top.get_xaxis().set_visible(False)
    top.get_yaxis().set_visible(False)

    bottom = plt.subplot(2, 5, col + 6)
    bottom.imshow(backdoored_images[col].cpu().view(28, 28).numpy(), cmap='gray')
    bottom.get_xaxis().set_visible(False)
    bottom.get_yaxis().set_visible(False)
plt.show()
# (end of script)