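# Untitled
#
# (Header residue from the pastecode.io page this script was copied from:
#  "mail@pastecode.io avatar", "unknown", "python", "5 months ago", "3.0 kB",
#  "3", "Indexable".)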
import torch
import os
import sys
import numpy as np
import cvxpy as cp
import matplotlib.pyplot as plt

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from src import data_preparation
from src import utils
from models.small_cnn import SmallCNN
import config
from scipy import optimize 

# Load the trained classifier that is being attacked.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SmallCNN().to(device)
model.load_state_dict(torch.load(
    os.path.join(os.path.dirname(__file__), '..', 'checkpoints', 'smallcnn_regular', 'model-nn-epoch10.pt'),
    map_location=device
))
# Switch to inference mode. A freshly constructed nn.Module starts in training
# mode. If SmallCNN contains dropout or batch-norm layers (not visible from
# here), training mode would make C(x) stochastic / batch-dependent, which
# breaks the optimizer. eval() is a no-op for layers without train/eval
# differences.
model.eval()

def C(x):
    """Run the classifier on a single image and return its raw outputs.

    Args:
        x: NumPy array holding 784 values that ``view`` can arrange as
            1x1x28x28 (e.g. shape (28, 28) or (784,)). Cast to float32.

    Returns:
        NumPy array with the model output for a batch of one. Presumably
        shape (1, n_classes) with n_classes == 10 (``g`` assumes 10); confirm
        against SmallCNN.
    """
    batch = torch.from_numpy(x).float().to(device).view(1, 1, 28, 28)
    with torch.no_grad():  # pure inference: no autograd graph is needed
        scores = model(batch)
    return scores.cpu().numpy()
    
def g(x_k, target_class=None):
    """Return the targeted-attack constraint values at ``x_k`` as a column vector.

    For target class j, component i is ``C_i(x_k) - C_j(x_k)``, i.e.
    ``g(x) = (I - 1 e_j^T) C(x)``. Class j has the largest model output exactly
    when every component is <= 0. Component j is always 0.

    Args:
        x_k: image forwarded to ``C`` (reshaped there to 1x1x28x28).
        target_class: index j of the desired class. Defaults to
            ``config.target_class``, the value this function used to hard-wire.

    Returns:
        float64 ndarray of shape (n_classes, 1).
    """
    j = config.target_class if target_class is None else target_class
    logits = np.asarray(C(x_k), dtype=np.float64).ravel()
    # Same as (I - ones @ e_j^T) @ C(x_k).T without building the matrix.
    # Beware: np.outer flattens its arguments, so np.outer(e_j, ones) is
    # e_j @ ones^T (the wrong orientation: it yields g_i = C_i for i != j and
    # g_j = C_j - sum(C)).
    return (logits - logits[j]).reshape(-1, 1)

def objective_function(x, x_original, tau, target_class):
    """Quadratic-penalty objective for the targeted attack.

    Computes ``0.5 * ||x - x_original||^2 + tau * sum_i max(0, g_i(x))^2`` with
    ``g_i(x) = C_i(x) - C_t(x)`` and ``t = target_class``.

    Args:
        x: candidate image, flat (784,) as passed by scipy, or (28, 28).
        x_original: the clean 28x28 image.
        tau: penalty weight (>= 0).
        target_class: index of the class the attack wants predicted.

    Returns:
        The objective value as a float.
    """
    x = np.reshape(x, (28, 28))
    distance = 0.5 * np.sum((x - x_original) ** 2)
    # Constraint values are computed from `target_class` directly rather than
    # from the global config, so the argument is actually honoured.
    logits = np.asarray(C(x), dtype=np.float64).ravel()
    violations = np.maximum(logits - logits[target_class], 0.0)
    penalty = tau * np.sum(violations ** 2)
    return distance + penalty

def _objective_and_gradient(x, x_original, tau, target_class):
    """Evaluate the penalized attack objective and its exact gradient w.r.t. ``x``.

    Computes ``0.5 * ||x - x_original||^2 + tau * sum_i max(0, C_i(x) - C_t(x))^2``
    (``t = target_class``) with torch autograd. ``x`` is the flat 784-vector
    handed over by scipy.

    Why not let L-BFGS-B use finite differences: scipy's default difference
    step is on the order of 1e-8, while the model is evaluated in float32
    (relative resolution ~1e-7). The penalty's numerical gradient then comes
    out as essentially zero or rounding noise, and the optimizer barely moves
    away from ``x_original``.

    Returns:
        ``(value, gradient)`` as ``(float, float64 ndarray shaped like x)``,
        the format ``scipy.optimize.minimize(..., jac=True)`` expects.
    """
    x_t = torch.tensor(np.asarray(x, dtype=np.float32), device=device, requires_grad=True)
    x0_t = torch.as_tensor(np.asarray(x_original, dtype=np.float32).ravel(), device=device)
    logits = model(x_t.view(1, 1, 28, 28)).reshape(-1)
    violations = torch.clamp(logits - logits[int(target_class)], min=0.0)
    value = 0.5 * torch.sum((x_t - x0_t) ** 2) + tau * torch.sum(violations ** 2)
    # autograd.grad (instead of backward) leaves the model weights' .grad untouched.
    (grad,) = torch.autograd.grad(value, x_t)
    return value.item(), grad.cpu().numpy().astype(np.float64)


def iterative_attack(x_original, target_class, max_iterations=100, tau_init=1.0, tau_max=1e6, tau_factor=10):
    """Targeted attack via a quadratic-penalty method with a growing weight.

    Each outer iteration minimizes the penalized objective with L-BFGS-B over
    the box [0, 1]^784, warm-started from the previous solution, using an
    exact autograd gradient. It stops as soon as the classifier predicts
    ``target_class``. Otherwise tau is multiplied by ``tau_factor``, capped at
    ``tau_max``. Each unsuccessful iteration is logged and its iterate shown.

    Args:
        x_original: clean 28x28 image (NumPy array).
        target_class: index of the class the attack wants predicted.
        max_iterations: maximum number of outer (penalty-update) iterations.
        tau_init, tau_max, tau_factor: initial penalty weight, its cap, and
            the multiplicative increase applied after each failed iteration.

    Returns:
        ``(x_adv, iterations)``: the last iterate as a 28x28 array and the
        number of outer iterations performed (0 if ``max_iterations`` <= 0).
    """
    x_original = np.asarray(x_original)
    x_adv = x_original.copy()
    tau = tau_init
    iterations = 0

    for iterations in range(1, max_iterations + 1):
        result = optimize.minimize(
            _objective_and_gradient,
            x_adv.ravel(),
            args=(x_original, tau, target_class),
            jac=True,  # the objective returns (value, gradient)
            method="L-BFGS-B",
            bounds=[(0, 1)] * x_adv.size,
        )
        x_adv = result.x.reshape(28, 28)

        # Check if the attack was successful
        predicted_class = np.argmax(C(x_adv))
        if predicted_class == target_class:
            print(f"Attack successful after {iterations} iterations!")
            break

        # Log with the tau that was actually used for this solve.
        print(f"Iteration {iterations}, tau: {tau:.2e}, objective value: {result.fun:.4f}")
        utils.show_image(x_adv, title=f"x_{iterations}")

        # Increase tau because the attack wasn't successful
        tau = min(tau * tau_factor, tau_max)
    else:
        if max_iterations > 0:
            print(f"Attack did not succeed within {max_iterations} iterations")

    return x_adv, iterations

# Pick a fixed (seeded) image of the original class and attack it.
x_original = utils.get_random_image(config.original_class, seed=111).numpy().squeeze()
target_class = config.target_class

x_adv, iterations = iterative_attack(x_original, target_class)
perturbation = x_adv - x_original

print(f"Attack ended in {iterations} iterations")
print(f"Original class: {np.argmax(C(x_original))}")
print(f"Adversarial class: {np.argmax(C(x_adv))}")

# Display the images (nothing is written to disk here).
utils.show_image(x_original, "Original Image")
utils.show_image(x_adv, "Adversarial Image")
utils.show_image(np.abs(perturbation), "Perturbation")

print(f"L2 norm of the perturbation: {np.linalg.norm(perturbation)}")
# (Footer residue from the pastecode.io page: "Leave a Comment".)