# main.py
#
# Paste-site metadata captured along with the code (not part of the program):
#   avatar | unknown | python | 4 months ago | 5.7 kB | 4 | Indexable
import os
import zipfile
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from data_processing import read_nifti_file,normalize,resize_volume,process_scan
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from augmentation import CTDataset,train_preprocessing,validation_preprocessing
from model import CNN3D
from utils import train_one_epoch,validate_one_epoch
# Working directory; both dataset folders below are listed relative to it.
current_dir = os.getcwd()


def _nifti_scan_paths(subdir):
    """Return the ``.nii.gz`` entries of *subdir*, joined onto ``current_dir``.

    *subdir* is listed relative to the current working directory. Entries keep
    ``os.listdir`` order, and anything that is not a NIfTI file is skipped.
    """
    nifti_names = [name for name in os.listdir(subdir) if name.endswith(".nii.gz")]
    return [os.path.join(current_dir, subdir, name) for name in nifti_names]


# Scans with normal lung tissue (CT-0) and with abnormal lung tissue (CT-23).
normal_scan_paths = _nifti_scan_paths("MosMedData/CT-0")
abnormal_scan_paths = _nifti_scan_paths("MosMedData/CT-23")

# Report how many scans of each class were found.
print(f"CT scans with normal lung tissue: {len(normal_scan_paths)}")
print(f"CT scans with abnormal lung tissue: {len(abnormal_scan_paths)}")

# 假设 `process_scan` 已定义并预处理每个 CT 扫描文件

# # 读取并处理异常 CT 扫描
# abnormal_scans = [process_scan(path) for path in abnormal_scan_paths]
# # 转换为 NumPy 数组
# abnormal_scans = np.array(abnormal_scans)

# # 读取并处理正常 CT 扫描
# normal_scans = [process_scan(path) for path in normal_scan_paths]
# # 转换为 NumPy 数组
# normal_scans = np.array(normal_scans)

# # 为异常 CT 扫描分配标签 1
# abnormal_labels = np.array([1 for _ in range(len(abnormal_scans))])
# # 为正常 CT 扫描分配标签 0
# normal_labels = np.array([0 for _ in range(len(normal_scans))])

# # 合并扫描和标签
# all_scans = np.concatenate((abnormal_scans, normal_scans), axis=0)
# all_labels = np.concatenate((abnormal_labels, normal_labels), axis=0)

# # 使用 sklearn 将数据集划分为训练集和验证集(70:30 比例)
# x_train, x_val, y_train, y_val = train_test_split(
#     all_scans, all_labels, test_size=0.3, random_state=42, stratify=all_labels
# )
# # 存成 pickle (save to a pickle file)

# # 转换为 PyTorch 张量
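# x_train = torch.tensor(x_train, dtype=torch.float32)  # no channel dimension is added here (dtype conversion only)
# y_train = torch.tensor(y_train, dtype=torch.long)
# x_val = torch.tensor(x_val, dtype=torch.float32)  # no channel dimension is added here (dtype conversion only)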
# y_val = torch.tensor(y_val, dtype=torch.long)

# # 打印训练和验证集的样本数量
# print(f"Number of samples in train and validation are {x_train.shape[0]} and {x_val.shape[0]}.")

import pickle

# # 保存预处理后的数据和标签
# processed_data = {
#     "x_train": x_train,
#     "y_train": y_train,
#     "x_val": x_val,
#     "y_val": y_val,
# }

# # 保存为 pickle 文件
# with open("processed_ct_data.pkl", "wb") as f:
#     pickle.dump(processed_data, f)
# print("数据已保存为 processed_ct_data.pkl")

# Load the preprocessed train/validation split. This is the file that the
# commented-out preprocessing code in this script writes with pickle.dump.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load a processed_ct_data.pkl that you generated yourself.
with open("processed_ct_data.pkl", "rb") as pkl_file:
    processed_data = pickle.load(pkl_file)

# Unpack the four splits, looked up in the same order as before.
x_train, y_train, x_val, y_val = (
    processed_data[split] for split in ("x_train", "y_train", "x_val", "y_val")
)

# Print the shapes of the loaded arrays/tensors.
print(f"Loaded x_train: {x_train.shape}, y_train: {y_train.shape}")
print(f"Loaded x_val: {x_val.shape}, y_val: {y_val.shape}")

# Mini-batch size shared by the training and validation loaders.
batch_size = 2

# Training data: wrapped with train_preprocessing (from the `augmentation`
# module; its exact transform is not visible here) and reshuffled by the
# DataLoader every epoch.
train_dataset = CTDataset(x_train, y_train, train_preprocessing)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation data: wrapped with validation_preprocessing, iterated in a fixed order.
validation_dataset = CTDataset(x_val, y_val, validation_preprocessing)
validation_loader = DataLoader(validation_dataset, batch_size=batch_size, shuffle=False)

# Smoke-test the training loader by pulling a single batch and printing it.
# Nothing is printed when the loader yields no batches.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    batch_idx = 0
    volumes, labels = first_batch
    print(f"Batch {batch_idx + 1}:")
    # Expected layout (batch_size, C, D, H, W); this depends on CTDataset, TODO confirm.
    print(f"Volumes shape: {volumes.shape}")
    print(f"Labels: {labels}")

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ExponentialLR
from tqdm import tqdm

# Training hyperparameters.
initial_learning_rate = 1e-4
epochs = 100
patience = 15  # epochs without a new best validation accuracy before early stopping
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Model, placed on the GPU when one is available (otherwise the CPU).
model = CNN3D()
model = model.to(device)

# NOTE(review): nn.BCELoss expects probabilities in [0, 1] (a model ending in a
# sigmoid) and float targets shaped like the input. The commented-out
# preprocessing code stored the labels as torch.long, so confirm that CNN3D
# applies a sigmoid and that train_one_epoch / validate_one_epoch cast the
# labels. If the model emits logits, nn.BCEWithLogitsLoss is the safer choice.
criterion = nn.BCELoss()  # binary cross-entropy
optimizer = optim.Adam(params=model.parameters(), lr=initial_learning_rate)

# Exponential learning-rate decay: every scheduler.step() multiplies the lr by
# 0.96. The training loop steps it once per epoch, so the (0-based) epoch k
# trains with 1e-4 * 0.96**k.
scheduler = ExponentialLR(optimizer, gamma=0.96)




# Main training loop: checkpoint on the best validation accuracy, and stop early
# after `patience` consecutive epochs without improvement.
#
# Start below any reachable accuracy so the first epoch always writes a
# checkpoint. With 0.0 and a strict `>`, a model stuck at 0% validation
# accuracy would never save "3d_image_classification.pth".
best_val_acc = float("-inf")
early_stop_counter = 0

# Log with tqdm.write rather than print: print() inside a tqdm-wrapped loop
# tears the progress bar apart. The `with` block closes the bar even when
# early stopping breaks out of the loop.
with tqdm(range(epochs), desc="Epochs", leave=True) as epoch_bar:
    for epoch in epoch_bar:
        tqdm.write(f"\nEpoch {epoch + 1}/{epochs}")

        # Train for one epoch.
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        tqdm.write(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")

        # Evaluate on the validation set.
        val_loss, val_acc = validate_one_epoch(model, validation_loader, criterion, device)
        tqdm.write(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_acc:.4f}")

        # Decay the learning rate once per epoch.
        scheduler.step()

        if val_acc > best_val_acc:
            # New best: reset the early-stopping counter and checkpoint the weights.
            best_val_acc = val_acc
            early_stop_counter = 0
            torch.save(model.state_dict(), "3d_image_classification.pth")
            tqdm.write("Saved best model!")
        else:
            # No improvement this epoch.
            early_stop_counter += 1

        if early_stop_counter >= patience:
            tqdm.write("Early stopping triggered.")
            break

# Report the best validation accuracy reached ("最高准确率" means "highest accuracy").
print("最高准确率{}".format(best_val_acc))

# (paste-site page residue, not part of the program)
# Editor is loading...
# Leave a Comment