import json
import random

import numpy as np
import torch


def random_samples(total_folds, random_seed=None):
    # Final_Samples.json lists the positive samples first, then the negative ones;
    # each entry is a (sample_id, label) pair.
    Final_Samples = json.load(open('Final_Samples.json', 'r'))
    # len_positive_samples (the number of positive entries; 654 in the original data)
    # is defined elsewhere in the script.
    positive_samples = Final_Samples[:len_positive_samples]
    negative_samples = Final_Samples[len_positive_samples:]
    if random_seed is not None:
        random.seed(random_seed * 2)
    random.shuffle(positive_samples)
    random.shuffle(negative_samples)
    # Keep a class-balanced subset of 500 positive and 500 negative samples.
    Final_Samples = positive_samples[:500] + negative_samples[:500]
    if random_seed is not None:
        random.seed(random_seed)
    random.shuffle(Final_Samples)
    # Split the 1000 (sample, label) pairs into `total_folds` equal folds.
    Final_Samples = np.array(Final_Samples)
    N_splits = Final_Samples.reshape(total_folds, -1, 2)
    return N_splits
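
# Shape sanity check (a sketch, assuming Final_Samples.json holds at least 500
# positive and 500 negative entries so that 1000 balanced pairs are kept):
# random_samples(total_folds=10, random_seed=0).shape  # -> (10, 100, 2)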
def generate_datasets(N_splits, fold_num, random_seed):
    # Hold out fold `fold_num` as the test set; the remaining folds form the training pool.
    test_samples = N_splits[fold_num:fold_num + 1].reshape([-1, 2])
    train_samples = np.concatenate([N_splits[0:fold_num], N_splits[fold_num + 1:]], 0).reshape([-1, 2]).tolist()
    if random_seed is not None:
        random.seed(random_seed * 3)
    random.shuffle(train_samples)
    train_samples = np.array(train_samples)
    split_pos = int(train_samples.shape[0] * 1.)  # keep the whole pool for training (empty validation set)
    # split_pos = int(train_samples.shape[0] * .8)  # alternative: 80/20 train/validation split
    train_samples, val_samples = train_samples[:split_pos], train_samples[split_pos:]
    # dataSet, FEATURE_MATRIX and usable_samples_ADNI are defined elsewhere in the script.
    train_set = dataSet(Final_Samples=train_samples,
                        feature_matrix=FEATURE_MATRIX,
                        usable_samples_ADNI=usable_samples_ADNI)
    val_set = dataSet(Final_Samples=val_samples,
                      feature_matrix=FEATURE_MATRIX,
                      usable_samples_ADNI=usable_samples_ADNI)
    test_set = dataSet(Final_Samples=test_samples,
                       feature_matrix=FEATURE_MATRIX,
                       usable_samples_ADNI=usable_samples_ADNI)
    # Normalise the PRS features using mean/std computed on the training set only.
    mean, std = train_set.get_mean_std()
    train_set.update_prs_features(mean, std)
    val_set.update_prs_features(mean, std)
    test_set.update_prs_features(mean, std)
    return train_set, val_set, test_set
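
# With total_folds=10 and the 1000 balanced pairs from random_samples, each fold
# therefore yields 100 test pairs and 900 training pairs (the 0.8 variant would
# move 180 of those 900 into the validation set).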
def generate_loader(train_set, val_set, test_set, num_workers):
    # Full-batch loading: each DataLoader returns its entire dataset in a single batch
    # (assumes every set is non-empty, since DataLoader rejects a batch size of 0).
    train_batch_size = len(train_set)
    val_batch_size = len(val_set)
    test_batch_size = len(test_set)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=train_batch_size,
                                               shuffle=True,
                                               pin_memory=torch.cuda.is_available(),
                                               num_workers=num_workers)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=val_batch_size,
                                             shuffle=False,
                                             pin_memory=torch.cuda.is_available(),
                                             num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(test_set,
                                              batch_size=test_batch_size,
                                              shuffle=False,
                                              pin_memory=torch.cuda.is_available(),
                                              num_workers=num_workers)
    return train_loader, val_loader, test_loader
train_set, val_set, test_set = generate_datasets(N_splits=random_samples(total_folds=10, random_seed=0),
                                                 fold_num=0, random_seed=0)
val_set.feature_matrix.shape
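
# A minimal full cross-validation sketch reusing the helpers above (kept commented
# out; it assumes the validation split in generate_datasets is non-empty, e.g. the
# 0.8 variant, since DataLoader rejects a batch size of 0, and num_workers=0 is an
# illustrative choice):
# N_splits = random_samples(total_folds=10, random_seed=0)
# for fold_num in range(10):
#     train_set, val_set, test_set = generate_datasets(N_splits, fold_num=fold_num, random_seed=0)
#     train_loader, val_loader, test_loader = generate_loader(train_set, val_set, test_set, num_workers=0)
#     # per-fold training / evaluation would go here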