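"""Per-speaker, per-duration-bucket fine-tuning on SLURP (summary of this script).

For each speaker in the curated headset split, utterances are bucketed by audio
duration (<= 2.7 s, 2.7-4 s, > 4 s). For every bucket a copy of the pretrained
SLU model is fine-tuned with a phoneme-level CTC loss built from CMUdict
pronunciations of the transcript, while k-means cluster ids and centers over the
model's CNN features of the augmented audio are computed and cached. The adapted
weights and the collected metadata (cluster ids/centers, phoneme labels, intents)
are saved per speaker and bucket under `folder_path`.
"""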
import ast
import os
import time
import gzip
import pickle
import warnings
from copy import deepcopy
from functools import reduce

import numpy as np
import pandas as pd
import soundfile as sf
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.tensorboard import SummaryWriter
from torchinfo import summary
from tqdm import tqdm
from nltk.corpus import cmudict
from nltk.tokenize import NLTKWordTokenizer
from kmeans_pytorch import kmeans

import data
import models
from utils import audio_augment, get_audio_duration

# Filter and suppress all warnings
warnings.filterwarnings("ignore")

config = data.read_config("experiments/no_unfreezing.cfg")
train_dataset, valid_dataset, test_dataset = data.get_SLU_datasets(config)
print('config load ok')

device = 'cpu'
pwd = os.getcwd()
wav_path = os.path.join(pwd, 'SLURP/slurp_real/')
# folder_path = os.path.join(pwd, 'models/SLURP/new-slurp_multicache_audio_bucket')
# folder_path = os.path.join(pwd, 'models/SLURP/curated-slurp-headset')
# folder_path = os.path.join(pwd, 'models/SLURP/curated-slurp-without-headset')
folder_path = os.path.join(pwd, 'models/SLURP/curated-slurp-headset-base')

# k-means settings for clustering the CNN features
NUM_CLUSTERS = 70
dist = 'euclidean'
tol = 1e-4

# phoneme inventory: id -> phoneme string, plus the reverse map
with open('phoneme_list.txt', 'r') as file:
    id2phoneme = ast.literal_eval(file.read())
phoneme2id = {v: k for k, v in id2phoneme.items()}

d = cmudict.dict()          # CMU pronouncing dictionary
tknz = NLTKWordTokenizer()  # word tokenizer for transcripts


def train(model, optim, df, ctc_loss, bucket):
    """Fine-tune `model` on one duration bucket of one speaker's utterances.

    Relies on the global `user_id` set in the speaker loop below.
    """
    num_nan_train = 0
    transcripts = np.unique(df['sentence'])
    training_idxs = set()
    # the following lists are kept for evaluation later
    transcript_list = []
    phoneme_list = []
    intent_list = []
    cluster_ids = []
    cluster_centers = []

    # training: one randomly chosen recording per distinct transcript
    for tscpt_idx, transcript in tqdm(enumerate(transcripts), total=len(transcripts), leave=False):
        optim.zero_grad()

        # build the phoneme sequence from cmudict; tokens missing from the
        # dictionary (e.g. punctuation) are dropped, words are joined with 'sp'
        phoneme_seq = reduce(lambda x, y: x + ['sp'] + y,
                             [d[tk][0] if tk in d else []
                              for tk in tknz.tokenize(transcript.lower())])
        transcript_list.append(transcript)

        # randomly choose one audio file for this distinct transcription
        rows = df[df['sentence'] == transcript]
        row = rows.iloc[np.random.randint(len(rows))]
        intent_list.append(row['intent'])
        # remember the chosen index; it is excluded from evaluation later
        training_idxs.add(row[0])

        # load and augment the audio file
        wav = wav_path + row['recording_path']
        x, _ = sf.read(wav)
        x_aug = torch.tensor(audio_augment(x), dtype=torch.float, device=device)

        # ----------------- kmeans cluster -----------------
        feature = model.pretrained_model.compute_cnn_features(x_aug)
        cluster_id, cluster_center = kmeans(X=feature.reshape(-1, feature.shape[-1]),
                                            num_clusters=NUM_CLUSTERS,
                                            distance=dist, tol=tol, device=device)
        # collapse consecutive duplicate cluster ids into a pseudo-label sequence
        intention_label = []
        prev = None
        for l in cluster_id.view(feature.shape[0], -1)[0]:
            if prev is None or prev != l:
                intention_label.append(l.item())
            prev = l
        cluster_ids.append(torch.tensor(intention_label, dtype=torch.long, device=device))
        # save the cluster centers
        cluster_centers.append(cluster_center)

        # ----------------- phoneme ctc -------------------
        # phoneme_seq, weight = get_token_and_weight(transcript.lower())
        phoneme_label = torch.tensor(
            [phoneme2id[ph[:-1]] if ph[-1].isdigit() else phoneme2id[ph] for ph in phoneme_seq],
            dtype=torch.long, device=device)
        phoneme_list.append(phoneme_label)
        phoneme_label = phoneme_label.repeat(x_aug.shape[0], 1)
        phoneme_pred = model.pretrained_model.compute_phonemes(x_aug)
        pred_lengths = torch.full(size=(x_aug.shape[0],), fill_value=phoneme_pred.shape[0], dtype=torch.long)
        label_lengths = torch.full(size=(x_aug.shape[0],), fill_value=phoneme_label.shape[-1], dtype=torch.long)
        loss = ctc_loss(phoneme_pred, phoneme_label, pred_lengths, label_lengths)

        # FIXME: the CTC loss occasionally goes NaN; count it and skip the
        # update so NaN gradients do not corrupt the model
        if torch.isnan(loss).any():
            num_nan_train += 1
            print('nan training on speaker: %s' % user_id)
            optim.zero_grad()
            continue

        loss.backward()
        optim.step()

    if num_nan_train:
        print('nan in train happens %d times' % num_nan_train)
    print('train %d test %d' % (len(training_idxs), len(df) - len(training_idxs)))

    # remove unnecessary layers
    # del model.pretrained_model.word_layers
    # del model.pretrained_model.word_linear
    # del model.intent_layers

    # save the fine-tuned weights and the per-speaker metadata
    filename = f'slurp_curated_headset_base_multicache_{user_id}_audio_bucket_{bucket}'
    file_path = os.path.join(folder_path, filename + '.pth')
    torch.save(model.state_dict(), file_path)

    metadata = {
        'df': df,
        'bucket': bucket,
        'speakerId': user_id,
        'transcript_list': transcript_list,
        'phoneme_list': phoneme_list,
        'intent_list': intent_list,
        'training_idxs': training_idxs,
        'cluster_ids': cluster_ids,
        'cluster_centers': cluster_centers,
    }
    with gzip.open(os.path.join(folder_path, filename + '.pkl.gz'), 'wb') as f:
        pickle.dump(metadata, f)


# slurp_df = pd.read_csv(os.path.join(pwd, 'slurp_mini_FE_MO_ME_FO_UNK.csv'))
# slurp_df = pd.read_csv(os.path.join(pwd, 'SLURP/csv/slurp_new_df.csv'))
# slurp_df = pd.read_csv(os.path.join(pwd, 'SLURP/csv/slurp_without_headset.csv'))
slurp_df = pd.read_csv(os.path.join(pwd, 'SLURP/csv/slurp_headset.csv'))
slurp_df = deepcopy(slurp_df)
speakers = np.unique(slurp_df['user_id'])
# speakers = ['MO-433', 'UNK-326', 'FO-232']
# speakers = ['FE-141']

cumulative_sample, cumulative_correct, cumulative_hit, cumulative_hit_correct = 0, 0, 0, 0

for _, user_id in tqdm(enumerate(speakers), total=len(speakers)):
    print('training for speaker %s' % user_id)
    df = slurp_df[slurp_df['user_id'] == user_id]

    # split this speaker's utterances into three buckets by audio duration
    df1 = pd.DataFrame(columns=slurp_df.columns)  # 0.0 - 2.7 s
    df2 = pd.DataFrame(columns=slurp_df.columns)  # 2.7 - 4.0 s
    df3 = pd.DataFrame(columns=slurp_df.columns)  # > 4.0 s
    for _, row in df.iterrows():
        wav = wav_path + row['recording_path']
        duration = get_audio_duration(wav)
        if 0 <= duration <= 2.7:
            df1 = df1._append(row)
        elif 2.7 < duration <= 4:
            df2 = df2._append(row)
        else:
            df3 = df3._append(row)
    # sanity check: every row should land in exactly one bucket
    print('%d' % (len(df1) + len(df2) + len(df3)))

    # pretrained_file = "slurp-pretrained.pth"
    # pretrained_path = os.path.join(pwd + "/models/SLURP/", pretrained_file)
    pretrained_path = "experiments/no_unfreezing/training/model_state.pth"

    # one model per bucket, each initialised from the same pretrained checkpoint
    model1 = models.Model(config).eval()
    optim1 = torch.optim.Adam(model1.parameters(), lr=1e-3)
    model1.load_state_dict(torch.load(pretrained_path, map_location=device))  # load trained model
    ctc_loss1 = torch.nn.CTCLoss()
    train(model1, optim1, df1, ctc_loss1, bucket=1)

    model2 = models.Model(config).eval()
    optim2 = torch.optim.Adam(model2.parameters(), lr=1e-3)
    model2.load_state_dict(torch.load(pretrained_path, map_location=device))  # load trained model
    ctc_loss2 = torch.nn.CTCLoss()
    train(model2, optim2, df2, ctc_loss2, bucket=2)

    model3 = models.Model(config).eval()
    optim3 = torch.optim.Adam(model3.parameters(), lr=1e-3)
    model3.load_state_dict(torch.load(pretrained_path, map_location=device))  # load trained model
    ctc_loss3 = torch.nn.CTCLoss()
    train(model3, optim3, df3, ctc_loss3, bucket=3)
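# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed by this script): loading one of
# the artifacts written above for evaluation. The paths and keys simply mirror
# the `torch.save` / `pickle.dump` calls in train(); `some_user_id` and
# `some_bucket` are placeholders.
#
#   filename = f'slurp_curated_headset_base_multicache_{some_user_id}_audio_bucket_{some_bucket}'
#   eval_model = models.Model(config).eval()
#   eval_model.load_state_dict(
#       torch.load(os.path.join(folder_path, filename + '.pth'), map_location=device))
#   with gzip.open(os.path.join(folder_path, filename + '.pkl.gz'), 'rb') as f:
#       metadata = pickle.load(f)  # dict with 'df', 'transcript_list', 'cluster_ids', ...
# ---------------------------------------------------------------------------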