augmentation

 avatar
unknown
python
4 months ago
2.1 kB
3
Indexable
import random
import torch
import numpy as np
from scipy import ndimage
import torch
from torch.utils.data import Dataset, DataLoader

def rotate(volume):
    """随机旋转体数据一定角度"""
    # 定义一些旋转角度
    angles = [-20, -10, -5, 5, 10, 20]
    # 随机选择一个角度
    angle = random.choice(angles)
    # 旋转体数据
    volume = ndimage.rotate(volume, angle, reshape=False)
    # 保证值在 [0, 1] 范围内
    volume[volume < 0] = 0
    volume[volume > 1] = 1
    return volume

def train_preprocessing(volume, label):
    """训练数据预处理,包括旋转和添加通道维度"""
    # 随机旋转体数据
    volume = rotate(volume)
    # 添加通道维度 (PyTorch 的 3D 卷积要求输入为 (C, D, H, W))
    volume = np.expand_dims(volume, axis=0)  # 添加通道维度
    # 转换为 PyTorch 张量
    volume = torch.tensor(volume, dtype=torch.float32).clone()
    label = torch.tensor(label, dtype=torch.long).clone()
    return volume, label

def validation_preprocessing(volume, label):
    """验证数据预处理,仅添加通道维度"""
    # 添加通道维度
    volume = np.expand_dims(volume, axis=0)  # 添加通道维度
    # 转换为 PyTorch 张量
    volume = torch.tensor(volume, dtype=torch.float32)
    label = torch.tensor(label, dtype=torch.long)
    return volume, label



class CTDataset(Dataset):
    """自定义数据集,用于加载 CT 扫描数据"""
    def __init__(self, data, labels, preprocess_fn):
        """
        Args:
            data (numpy array): CT 扫描数据
            labels (numpy array): 标签
            preprocess_fn (callable): 预处理函数
        """
        self.data = data
        self.labels = labels
        self.preprocess_fn = preprocess_fn

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        volume = self.data[idx]
        label = self.labels[idx]
        # 应用预处理函数
        volume, label = self.preprocess_fn(volume, label)
        return volume, label

Editor is loading...
Leave a Comment