Untitled
# 3D convolutional variational autoencoder (VAE) trained on volumes stored as .npy files.
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (Conv3D, MaxPooling3D, UpSampling3D, Input, Dense,
                                     Flatten, Reshape, Lambda, BatchNormalization, Cropping3D)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping
import matplotlib.pyplot as plt


# Function to load volumes from a specific directory
def load_volumes_from_directory(data_path, start_index, end_index):
    volumes = []
    for i in range(start_index, end_index + 1):
        volume = np.load(os.path.join(data_path, f'p{i}/final_volume.npy'))
        volumes.append(volume)
    return volumes


# Directory containing the normal (non-anomalous) volumes
normal_data_path = '/kaggle/input/process-data'

# Load normal volumes
normal_volumes = load_volumes_from_directory(normal_data_path, 1, 20)  # Adjust the range as needed

# Convert to a numpy array
normal_volumes = np.array(normal_volumes)
print(f"Loaded normal volumes shape: {normal_volumes.shape}")

# Downsample height and width to reduce memory usage (depth is left unchanged)
normal_volumes = normal_volumes[:, :, ::2, ::2]  # Adjust the downsampling factor as needed

# Normalize the volumes to [0, 1]
normal_volumes = normal_volumes / np.max(normal_volumes)

# Expand dimensions to fit the model input (batch_size, depth, height, width, channels)
normal_volumes = np.expand_dims(normal_volumes, axis=-1)

# Input shape is taken from the data; adjust if your volumes differ
input_shape = normal_volumes[0].shape
latent_dim = 128  # Dimension of the latent space


# Sampling function for the reparameterization trick
def sampling(args):
    z_mean, z_log_var = args
    batch = tf.shape(z_mean)[0]
    dim = tf.shape(z_mean)[1]
    epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon


# Encoder: four Conv3D + BatchNorm + MaxPooling3D blocks, then dense layers for z_mean and z_log_var
input_img = Input(shape=input_shape, name='encoder_input')
x = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(input_img)
x = BatchNormalization()(x)
x = MaxPooling3D((2, 2, 2), padding='same')(x)
x = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling3D((2, 2, 2), padding='same')(x)
x = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling3D((2, 2, 2), padding='same')(x)
x = Conv3D(256, (3, 3, 3), activation='relu', padding='same')(x)
x = BatchNormalization()(x)
x = MaxPooling3D((2, 2, 2), padding='same')(x)
x = Flatten()(x)
x = Dense(512, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

# Sample z from the latent distribution
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

# Define the encoder
encoder = Model(input_img, [z_mean, z_log_var, z], name='encoder')
encoder.summary()

# Decoder: mirror of the encoder, upsampling back to the original volume size
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(512, activation='relu')(latent_inputs)
# (19, 16, 16, 256) matches the encoder's final feature map: four 2x poolings divide each
# spatial dimension by 16, so this assumes downsampled input volumes of shape (304, 256, 256)
x = Dense(19 * 16 * 16 * 256, activation='relu')(x)
x = Reshape((19, 16, 16, 256))(x)
x = Conv3D(256, kernel_size=3, padding='same', activation='relu')(x)
x = BatchNormalization()(x)
x = UpSampling3D(size=(2, 2, 2))(x)
x = Conv3D(128, kernel_size=3, padding='same', activation='relu')(x)
x = BatchNormalization()(x)
x = UpSampling3D(size=(2, 2, 2))(x)
x = Conv3D(64, kernel_size=3, padding='same', activation='relu')(x)
x = BatchNormalization()(x)
x = UpSampling3D(size=(2, 2, 2))(x)
x = Conv3D(32, kernel_size=3, padding='same', activation='relu')(x)
x = BatchNormalization()(x)
x = UpSampling3D(size=(2, 2, 2))(x)
decoder_outputs = Conv3D(1, kernel_size=3, activation='sigmoid', padding='same')(x)

decoder = Model(latent_inputs, decoder_outputs, name='decoder')
decoder.summary()

# Full VAE: encode, sample, decode
outputs = decoder(encoder(input_img)[2])
vae = Model(input_img, outputs, name='vae')


# VAE loss: reconstruction term (binary cross-entropy) + KL divergence.
# Note: this loss closes over the symbolic tensors z_mean and z_log_var; on TF 2.x with eager
# execution enabled this pattern can raise an error, in which case vae.add_loss (or disabling
# eager execution) is the usual workaround.
def vae_loss(input_img, decoded):
    input_img_flat = K.flatten(input_img)
    decoded_flat = K.flatten(decoded)
    reconstruction_loss = binary_crossentropy(input_img_flat, decoded_flat) * np.prod(input_shape)
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = K.sum(kl_loss, axis=-1)
    kl_loss *= -0.5
    return K.mean(reconstruction_loss + kl_loss)


vae.compile(optimizer='adam', loss=vae_loss)

early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

history = vae.fit(
    normal_volumes, normal_volumes,  # Train only on normal volumes (reconstruction target = input)
    epochs=50,
    batch_size=1,
    validation_split=0.2,  # Use 20% of the data for validation
    callbacks=[early_stopping]
)

# Save the model (reloading it later may require passing the sampling function and vae_loss
# via custom_objects, or using compile=False)
vae.save('/kaggle/working/3d_vae_model_normal.h5')

# Plot training and validation loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper right')
plt.show()
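
# Optional usage sketch: reconstruct one of the training volumes with the trained VAE and
# compare a middle slice to the input. This is a minimal illustration, assuming the arrays
# and models defined above are still in memory; the slice index and figure layout are arbitrary.
sample = normal_volumes[:1]             # One volume, shape (1, depth, height, width, 1)
reconstruction = vae.predict(sample)    # VAE reconstruction of that volume

mid = sample.shape[1] // 2              # Middle slice along the depth axis
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(sample[0, mid, :, :, 0], cmap='gray')
axes[0].set_title('Input slice')
axes[1].imshow(reconstruction[0, mid, :, :, 0], cmap='gray')
axes[1].set_title('VAE reconstruction')
plt.show()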