augmentation
unknown
python
a year ago
2.1 kB
5
Indexable
import random
import torch
import numpy as np
from scipy import ndimage
import torch
from torch.utils.data import Dataset, DataLoader
def rotate(volume):
"""随机旋转体数据一定角度"""
# 定义一些旋转角度
angles = [-20, -10, -5, 5, 10, 20]
# 随机选择一个角度
angle = random.choice(angles)
# 旋转体数据
volume = ndimage.rotate(volume, angle, reshape=False)
# 保证值在 [0, 1] 范围内
volume[volume < 0] = 0
volume[volume > 1] = 1
return volume
def train_preprocessing(volume, label):
"""训练数据预处理,包括旋转和添加通道维度"""
# 随机旋转体数据
volume = rotate(volume)
# 添加通道维度 (PyTorch 的 3D 卷积要求输入为 (C, D, H, W))
volume = np.expand_dims(volume, axis=0) # 添加通道维度
# 转换为 PyTorch 张量
volume = torch.tensor(volume, dtype=torch.float32).clone()
label = torch.tensor(label, dtype=torch.long).clone()
return volume, label
def validation_preprocessing(volume, label):
"""验证数据预处理,仅添加通道维度"""
# 添加通道维度
volume = np.expand_dims(volume, axis=0) # 添加通道维度
# 转换为 PyTorch 张量
volume = torch.tensor(volume, dtype=torch.float32)
label = torch.tensor(label, dtype=torch.long)
return volume, label
class CTDataset(Dataset):
"""自定义数据集,用于加载 CT 扫描数据"""
def __init__(self, data, labels, preprocess_fn):
"""
Args:
data (numpy array): CT 扫描数据
labels (numpy array): 标签
preprocess_fn (callable): 预处理函数
"""
self.data = data
self.labels = labels
self.preprocess_fn = preprocess_fn
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
volume = self.data[idx]
label = self.labels[idx]
# 应用预处理函数
volume, label = self.preprocess_fn(volume, label)
return volume, label
Editor is loading...
Leave a Comment