# main.py
# unknown
# python
# 4 months ago
# 5.7 kB
# 4
# Indexable
"""Train a 3D CNN (``CNN3D``) to classify lung CT volumes as normal (0) or abnormal (1).

Intended to be run as a script. The preprocessed train/validation split is cached
in ``processed_ct_data.pkl`` as a pickled dict of tensors. When that file does not
exist, it is rebuilt from the ``.nii.gz`` scans in ``MosMedData/CT-0`` (normal) and
``MosMedData/CT-23`` (abnormal), both resolved against the current working directory.
The weights from the epoch with the best validation accuracy are saved to
``3d_image_classification.pth``.
"""
import os
import pickle
import zipfile  # noqa: F401  (unused in this script; kept from the original imports)

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F  # noqa: F401
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, Dataset  # noqa: F401
from tqdm import tqdm

from data_processing import read_nifti_file, normalize, resize_volume, process_scan  # noqa: F401
from augmentation import CTDataset, train_preprocessing, validation_preprocessing
from model import CNN3D
from utils import train_one_epoch, validate_one_epoch

# Data locations (relative to the current working directory).
DATA_DIR = "MosMedData"
NORMAL_SUBDIR = "CT-0"  # scans labelled 0 ("normal lung tissue")
ABNORMAL_SUBDIR = "CT-23"  # scans labelled 1 ("abnormal lung tissue")
CACHE_PATH = "processed_ct_data.pkl"
CHECKPOINT_PATH = "3d_image_classification.pth"

# Hyperparameters.
BATCH_SIZE = 2
INITIAL_LEARNING_RATE = 0.0001
EPOCHS = 100
PATIENCE = 15  # epochs without a new best validation accuracy before stopping early
LR_DECAY = 0.96  # ExponentialLR gamma: the learning rate is multiplied by this every epoch
VAL_FRACTION = 0.3  # share of scans held out for validation (70:30 split)
SPLIT_SEED = 42


def _list_nifti_paths(subdir):
    """Return the absolute paths of the ``.nii.gz`` files in ``DATA_DIR/subdir``.

    The paths are sorted. ``os.listdir`` order is arbitrary, so sorting makes the
    seeded train/validation split reproducible across machines.
    """
    folder = os.path.join(os.getcwd(), DATA_DIR, subdir)
    return [
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if name.endswith(".nii.gz")  # only load NIfTI volumes
    ]


def _build_processed_data(normal_scan_paths, abnormal_scan_paths):
    """Preprocess every scan, label it, and return a stratified 70:30 split as tensors.

    Abnormal scans get label 1 and normal scans label 0. The scans are stacked with
    ``np.array``, so ``process_scan`` must return arrays of one common shape;
    presumably it resizes the volumes, but confirm this.

    Returns:
        dict with keys ``x_train``/``x_val`` (float32 tensors) and
        ``y_train``/``y_val`` (int64 tensors).
    """
    abnormal_scans = np.array([process_scan(path) for path in abnormal_scan_paths])
    normal_scans = np.array([process_scan(path) for path in normal_scan_paths])
    all_scans = np.concatenate((abnormal_scans, normal_scans), axis=0)
    all_labels = np.concatenate(
        (
            np.ones(len(abnormal_scans), dtype=np.int64),
            np.zeros(len(normal_scans), dtype=np.int64),
        ),
        axis=0,
    )

    x_train, x_val, y_train, y_val = train_test_split(
        all_scans,
        all_labels,
        test_size=VAL_FRACTION,
        random_state=SPLIT_SEED,
        stratify=all_labels,  # keep the class ratio equal in both splits
    )

    # No channel dimension is added here. Presumably CTDataset or its transforms add
    # it (the loader smoke test expects (batch_size, C, D, H, W)); confirm.
    processed_data = {
        "x_train": torch.tensor(x_train, dtype=torch.float32),
        "y_train": torch.tensor(y_train, dtype=torch.long),
        "x_val": torch.tensor(x_val, dtype=torch.float32),
        "y_val": torch.tensor(y_val, dtype=torch.long),
    }
    n_train = processed_data["x_train"].shape[0]
    n_val = processed_data["x_val"].shape[0]
    print(f"Number of samples in train and validation are {n_train} and {n_val}.")
    return processed_data


def _load_or_build_processed_data(normal_scan_paths, abnormal_scan_paths):
    """Return the cached split dict, building and caching it first if CACHE_PATH is missing.

    The dict has keys ``x_train``, ``y_train``, ``x_val`` and ``y_val``.
    """
    if not os.path.exists(CACHE_PATH):
        processed_data = _build_processed_data(normal_scan_paths, abnormal_scan_paths)
        with open(CACHE_PATH, "wb") as f:
            pickle.dump(processed_data, f)
        print(f"数据已保存为 {CACHE_PATH}")  # "data saved as <path>"
        return processed_data

    # SECURITY: unpickling can execute arbitrary code. Only load a cache this script wrote.
    with open(CACHE_PATH, "rb") as f:
        return pickle.load(f)


def _train(model, train_loader, validation_loader, device):
    """Train *model* with Adam, exponential LR decay, and early stopping.

    The state dict is saved to CHECKPOINT_PATH whenever validation accuracy strictly
    improves. Training stops after PATIENCE consecutive epochs without improvement.

    Returns:
        The best validation accuracy seen, or 0.0 if it never exceeded 0.0.
    """
    # NOTE(review): BCELoss expects probabilities in [0, 1] and float targets shaped like
    # the model output, while the labels here are torch.long. Presumably CNN3D ends in a
    # sigmoid and train_one_epoch / validate_one_epoch cast the labels; confirm.
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=INITIAL_LEARNING_RATE)
    scheduler = ExponentialLR(optimizer, gamma=LR_DECAY)

    best_val_acc = 0.0
    early_stop_counter = 0
    for epoch in tqdm(range(EPOCHS), desc="Epochs", leave=True):
        print(f"\nEpoch {epoch + 1}/{EPOCHS}")

        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")

        val_loss, val_acc = validate_one_epoch(model, validation_loader, criterion, device)
        print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")

        # Decay the learning rate once per epoch.
        scheduler.step()

        if val_acc > best_val_acc:
            # New best: checkpoint it and reset the early-stopping counter.
            best_val_acc = val_acc
            early_stop_counter = 0
            torch.save(model.state_dict(), CHECKPOINT_PATH)
            print("Saved best model!")
        else:
            early_stop_counter += 1
            if early_stop_counter >= PATIENCE:
                print("Early stopping triggered.")
                break
    return best_val_acc


def main():
    """Run the pipeline: list scans, load or build the split, build loaders, train."""
    normal_scan_paths = _list_nifti_paths(NORMAL_SUBDIR)
    abnormal_scan_paths = _list_nifti_paths(ABNORMAL_SUBDIR)
    print(f"CT scans with normal lung tissue: {len(normal_scan_paths)}")
    print(f"CT scans with abnormal lung tissue: {len(abnormal_scan_paths)}")

    processed_data = _load_or_build_processed_data(normal_scan_paths, abnormal_scan_paths)
    x_train = processed_data["x_train"]
    y_train = processed_data["y_train"]
    x_val = processed_data["x_val"]
    y_val = processed_data["y_val"]
    print(f"Loaded x_train: {x_train.shape}, y_train: {y_train.shape}")
    print(f"Loaded x_val: {x_val.shape}, y_val: {y_val.shape}")

    train_loader = DataLoader(
        CTDataset(x_train, y_train, train_preprocessing),
        batch_size=BATCH_SIZE,
        shuffle=True,
    )
    validation_loader = DataLoader(
        CTDataset(x_val, y_val, validation_preprocessing),
        batch_size=BATCH_SIZE,
        shuffle=False,
    )

    # Smoke-test the training loader by printing its first batch (if any).
    first_batch = next(iter(train_loader), None)
    if first_batch is not None:
        volumes, labels = first_batch
        print("Batch 1:")
        print(f"Volumes shape: {volumes.shape}")  # expected (batch_size, C, D, H, W)
        print(f"Labels: {labels}")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CNN3D().to(device)
    best_val_acc = _train(model, train_loader, validation_loader, device)
    print("最高准确率{}".format(best_val_acc))  # "best accuracy"


if __name__ == "__main__":
    main()
# Editor is loading...
# Leave a Comment