augmentation
unknown
python
4 months ago
2.1 kB
3
Indexable
import random import torch import numpy as np from scipy import ndimage import torch from torch.utils.data import Dataset, DataLoader def rotate(volume): """随机旋转体数据一定角度""" # 定义一些旋转角度 angles = [-20, -10, -5, 5, 10, 20] # 随机选择一个角度 angle = random.choice(angles) # 旋转体数据 volume = ndimage.rotate(volume, angle, reshape=False) # 保证值在 [0, 1] 范围内 volume[volume < 0] = 0 volume[volume > 1] = 1 return volume def train_preprocessing(volume, label): """训练数据预处理,包括旋转和添加通道维度""" # 随机旋转体数据 volume = rotate(volume) # 添加通道维度 (PyTorch 的 3D 卷积要求输入为 (C, D, H, W)) volume = np.expand_dims(volume, axis=0) # 添加通道维度 # 转换为 PyTorch 张量 volume = torch.tensor(volume, dtype=torch.float32).clone() label = torch.tensor(label, dtype=torch.long).clone() return volume, label def validation_preprocessing(volume, label): """验证数据预处理,仅添加通道维度""" # 添加通道维度 volume = np.expand_dims(volume, axis=0) # 添加通道维度 # 转换为 PyTorch 张量 volume = torch.tensor(volume, dtype=torch.float32) label = torch.tensor(label, dtype=torch.long) return volume, label class CTDataset(Dataset): """自定义数据集,用于加载 CT 扫描数据""" def __init__(self, data, labels, preprocess_fn): """ Args: data (numpy array): CT 扫描数据 labels (numpy array): 标签 preprocess_fn (callable): 预处理函数 """ self.data = data self.labels = labels self.preprocess_fn = preprocess_fn def __len__(self): return len(self.data) def __getitem__(self, idx): volume = self.data[idx] label = self.labels[idx] # 应用预处理函数 volume, label = self.preprocess_fn(volume, label) return volume, label
Editor is loading...
Leave a Comment