Untitled

mail@pastecode.io avatar
unknown
plain_text
5 months ago
5.2 kB
0
Indexable
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv3D, MaxPooling3D, UpSampling3D, Input, Dense, Flatten, Reshape, Lambda, BatchNormalization, Cropping3D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras import backend as K
from tensorflow.keras.callbacks import EarlyStopping
import matplotlib.pyplot as plt

# Function to load volumes from a specific directory
def load_volumes_from_directory(data_path, start_index, end_index, *, filename='final_volume.npy'):
    """Load one saved NumPy volume per patient sub-directory.

    Reads ``<data_path>/p<i>/<filename>`` for every ``i`` from
    ``start_index`` to ``end_index`` **inclusive**, in that order.

    Args:
        data_path: Root directory containing the ``p<i>`` sub-directories.
        start_index: First index to load.
        end_index: Last index to load (inclusive).
        filename: Name of the ``.npy`` file inside each sub-directory.

    Returns:
        A list of the arrays returned by ``np.load``, one per index.
        The list is empty when ``end_index < start_index``.

    Raises:
        FileNotFoundError: If an expected file does not exist.
    """
    # Pass each path component separately, rather than embedding '/' in the
    # f-string, so os.path.join applies the platform's separator throughout.
    return [
        np.load(os.path.join(data_path, f'p{i}', filename))
        for i in range(start_index, end_index + 1)
    ]

# Assuming the volumes are stored in separate directories for normal volumes
normal_data_path = '/kaggle/input/process-data'

# Load normal volumes (sub-directories p1..p20, inclusive)
normal_volumes = load_volumes_from_directory(normal_data_path, 1, 20)  # Adjust as needed

# Stack into a single (num_volumes, ...) array. np.stack fails with a clear
# message if the volumes do not all share the same shape.
normal_volumes = np.stack(normal_volumes)

print(f"Loaded normal volumes shape: {normal_volumes.shape}")

# Downsample volumes to reduce memory usage: keep every 2nd sample along
# axes 2 and 3 (axis 1 stays at full resolution). Cast to float32 after
# slicing, so only the reduced array is copied and no float64 intermediate
# is created. The copy also makes the in-place ops below safe.
normal_volumes = normal_volumes[:, :, ::2, ::2].astype(np.float32)  # Adjust downsampling factor

# Normalize the volumes to [0, 1] with min-max scaling. Dividing by the max
# alone leaves negative intensities negative, and it divides by zero on an
# all-zero dataset. Binary cross-entropy (used as the reconstruction loss)
# needs targets in [0, 1]. A constant dataset maps to all zeros.
v_min = normal_volumes.min()
v_range = normal_volumes.max() - v_min
normal_volumes -= v_min
if v_range > 0:
    normal_volumes /= v_range

# Expand dimensions to fit the model input (batch_size, depth, height, width, channels)
normal_volumes = np.expand_dims(normal_volumes, axis=-1)

# Per-sample input shape for the encoder (the batch axis is dropped)
input_shape = normal_volumes[0].shape
latent_dim = 128  # Dimension of the latent space

# Define the sampling function for the reparameterization trick
def sampling(args):
    """Draw one latent sample per row using the reparameterization trick.

    Computes ``z = z_mean + exp(0.5 * z_log_var) * eps`` with
    ``eps ~ N(0, I)``. The randomness lives only in ``eps``, so gradients
    flow back through ``z_mean`` and ``z_log_var``.

    Args:
        args: Pair ``(z_mean, z_log_var)`` of 2-D tensors, (batch, latent).

    Returns:
        Tensor of samples with the same (batch, latent) shape.
    """
    mu, log_var = args
    mu_shape = tf.shape(mu)
    eps = tf.keras.backend.random_normal(shape=(mu_shape[0], mu_shape[1]))
    std = tf.exp(0.5 * log_var)
    return mu + std * eps

# Define the encoder model: four Conv3D -> BatchNorm -> MaxPool(2) stages with
# 32, 64, 128 and 256 filters, then a dense layer that produces the mean and
# log-variance of the latent distribution.
input_img = Input(shape=input_shape, name='encoder_input')
x = input_img
for n_filters in (32, 64, 128, 256):
    x = Conv3D(n_filters, (3, 3, 3), activation='relu', padding='same')(x)
    x = BatchNormalization()(x)
    x = MaxPooling3D((2, 2, 2), padding='same')(x)
x = Flatten()(x)
x = Dense(512, activation='relu')(x)
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

# Sample z from (z_mean, z_log_var) inside the graph via the `sampling` function.
z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

# The encoder exposes all three tensors: mean, log-variance, and the sample.
encoder = Model(input_img, [z_mean, z_log_var, z], name='encoder')
encoder.summary()

# Define the decoder model
# The encoder applies four MaxPooling3D((2, 2, 2), padding='same') layers, so
# each spatial axis of its bottleneck is ceil(dim / 16). Derive the decoder's
# starting grid from input_shape instead of hard-coding it. Then crop the
# round-up excess off the output so reconstructions have exactly the input's
# shape, which the reconstruction loss requires.
n_pool_layers = 4
pool_factor = 2 ** n_pool_layers
spatial_shape = tuple(int(d) for d in input_shape[:3])
bottleneck_shape = tuple(-(-d // pool_factor) for d in spatial_shape)  # ceil division
upsampled_shape = tuple(d * pool_factor for d in bottleneck_shape)

latent_inputs = Input(shape=(latent_dim,), name='z_sampling')
x = Dense(512, activation='relu')(latent_inputs)
# NOTE(review): with a 19x16x16 bottleneck this layer holds 512 * 1,245,184
# (about 6.4e8) weights, roughly 2.5 GB in float32. Consider a smaller
# bottleneck or fewer filters if memory is tight.
x = Dense(int(np.prod(bottleneck_shape)) * 256, activation='relu')(x)
x = Reshape(bottleneck_shape + (256,))(x)

# Mirror the encoder, widest first: Conv3D -> BatchNorm -> 2x upsampling per stage.
for filters in (256, 128, 64, 32):
    x = Conv3D(filters, kernel_size=3, padding='same', activation='relu')(x)
    x = BatchNormalization()(x)
    x = UpSampling3D(size=(2, 2, 2))(x)

decoder_outputs = Conv3D(1, kernel_size=3, activation='sigmoid', padding='same')(x)

# Trim the trailing voxels added by ceil rounding. No layer is added when the
# input dims are already multiples of 16.
crop = tuple((0, up - orig) for up, orig in zip(upsampled_shape, spatial_shape))
if any(right for _, right in crop):
    decoder_outputs = Cropping3D(cropping=crop)(decoder_outputs)

decoder = Model(latent_inputs, decoder_outputs, name='decoder')
decoder.summary()

# Define the VAE model: encode, sample z, decode.
# Keep the tensors from *this* encoder call so the KL term below is computed
# on the same forward pass as the reconstruction.
vae_z_mean, vae_z_log_var, vae_z = encoder(input_img)
outputs = decoder(vae_z)
vae = Model(input_img, outputs, name='vae')

# KL divergence between N(z_mean, exp(z_log_var)) and the N(0, I) prior,
# summed over latent dimensions and averaged over the batch.
# A compiled Keras loss only receives (y_true, y_pred). Referencing the
# symbolic z_mean/z_log_var from inside one fails under TF2 eager execution,
# so the KL term is attached to the model with add_loss instead.
kl_loss = 1 + vae_z_log_var - K.square(vae_z_mean) - K.exp(vae_z_log_var)
kl_loss = -0.5 * K.sum(kl_loss, axis=-1)
vae.add_loss(K.mean(kl_loss))

# Define VAE loss
def vae_loss(input_img, decoded):
    """Reconstruction term of the VAE objective, used as the compiled loss.

    Flattens both tensors (batch axis included) and takes the mean binary
    cross-entropy over every voxel. It then rescales by the voxel count of
    one volume, giving the per-volume summed cross-entropy averaged over
    the batch.

    The KL term is not computed here. It is attached with ``vae.add_loss``
    above, and Keras adds it to this value, so the optimized total is
    reconstruction + KL.

    Args:
        input_img: Target volumes (y_true).
        decoded: Reconstructed volumes (y_pred), same shape as ``input_img``.

    Returns:
        Scalar reconstruction loss.
    """
    input_img_flat = K.flatten(input_img)
    decoded_flat = K.flatten(decoded)
    reconstruction_loss = binary_crossentropy(input_img_flat, decoded_flat) * np.prod(input_shape)
    return K.mean(reconstruction_loss)

vae.compile(optimizer='adam', loss=vae_loss)
early_stopping = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# Fit the VAE as an autoencoder on normal volumes only: the input is also the
# target. 20% of the samples are held out by Keras for validation, and
# early_stopping watches that validation loss.
history = vae.fit(
    x=normal_volumes,
    y=normal_volumes,
    batch_size=1,
    epochs=50,
    validation_split=0.2,
    callbacks=[early_stopping],
)

# Persist the trained model to an HDF5 (.h5) file in the Kaggle working dir.
vae.save('/kaggle/working/3d_vae_model_normal.h5')


# Plot the per-epoch training and validation loss curves recorded by fit().
for series in ('loss', 'val_loss'):
    plt.plot(history.history[series])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper right')
plt.show()
Leave a Comment