import ast
import data
import os
import time
import soundfile as sf
import torch
import numpy as np
from functools import reduce
from nltk.corpus import cmudict
from nltk.tokenize import NLTKWordTokenizer
from copy import deepcopy
from tqdm import tqdm
from torchinfo import summary
from torch.utils.tensorboard import SummaryWriter
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from utils import audio_augment, get_audio_duration
import models
import pandas as pd
import pickle
from kmeans_pytorch import kmeans
import warnings
import gzip
# Silence every warning category for the whole run.
warnings.filterwarnings("ignore")

# Experiment configuration and the dataset splits built from it.
config = data.read_config("experiments/no_unfreezing.cfg")
train_dataset, valid_dataset, test_dataset = data.get_SLU_datasets(config)
print('config load ok')

# Everything below runs on the CPU.
device = 'cpu'

# Input audio root and output folder, both resolved against the current working directory.
pwd = os.getcwd()
wav_path = os.path.join(pwd, 'SLURP/slurp_real/')
# Alternative output folders used by other experiment variants:
# folder_path = os.path.join(pwd, 'models/SLURP/new-slurp_multicache_audio_bucket')
# folder_path = os.path.join(pwd, 'models/SLURP/curated-slurp-headset')
# folder_path = os.path.join(pwd, 'models/SLURP/curated-slurp-without-headset')
folder_path = os.path.join(pwd, 'models/SLURP/curated-slurp-headset-base')

# Arguments forwarded to the k-means call in train().
NUM_CLUSTERS, dist, tol = 70, 'euclidean', 1e-4

# phoneme_list.txt holds a Python dict literal (id -> phoneme); build the inverse map as well.
with open('phoneme_list.txt', 'r') as file:
    phoneme_table_src = file.read()
id2phoneme = ast.literal_eval(phoneme_table_src)
phoneme2id = {phoneme: idx for idx, phoneme in id2phoneme.items()}

# CMU pronouncing dictionary (word -> list of pronunciations) and the word tokenizer.
d = cmudict.dict()
tknz = NLTKWordTokenizer()
def _transcript_to_phonemes(transcript):
    """Return the phoneme sequence of ``transcript`` with ``'sp'`` between words.

    Each token is looked up in the CMU dictionary ``d`` and its first
    pronunciation is used; tokens missing from the dictionary contribute an
    empty list (so their surrounding ``'sp'`` separators remain).
    """
    per_word = [d[tk][0] if tk in d else [] for tk in tknz.tokenize(transcript.lower())]
    return reduce(lambda x, y: x + ['sp'] + y, per_word)


def train(model, optim, df, ctc_loss, bucket):
    """Fine-tune ``model`` with a phoneme CTC loss and save the model plus metadata.

    For every distinct value of ``df['sentence']`` one matching row is sampled
    at random, its audio is loaded and augmented, the CNN features are
    clustered with k-means, and one optimiser step is taken on the CTC loss
    between the predicted phonemes and the CMUdict phonemes of the transcript.

    Besides its arguments, this function reads the module-level names
    ``user_id`` (current speaker, set by the driver loop), ``wav_path``,
    ``folder_path``, ``device``, ``NUM_CLUSTERS``, ``dist``, ``tol``,
    ``phoneme2id``, ``d`` and ``tknz``.

    Args:
        model: model exposing ``pretrained_model.compute_cnn_features`` and
            ``pretrained_model.compute_phonemes``.
        optim: optimiser over ``model``'s parameters.
        df: DataFrame with at least the columns ``sentence``, ``intent`` and
            ``recording_path``.
        ctc_loss: callable invoked as
            ``ctc_loss(pred, labels, pred_lengths, label_lengths)``.
        bucket: bucket identifier, used in the output file name and metadata.

    Side effects:
        Writes ``<folder_path>/<name>.pth`` (model ``state_dict``) and
        ``<folder_path>/<name>.pkl.gz`` (gzip-pickled metadata dict).
    """
    num_nan_train = 0
    transcripts = np.unique(df['sentence'])
    training_idxs = set()
    # The following lists are stored in the metadata for later evaluation.
    transcript_list = []
    phoneme_list = []
    intent_list = []
    cluster_ids = []
    cluster_centers = []

    for transcript in tqdm(transcripts, total=len(transcripts), leave=False):
        optim.zero_grad()
        transcript_list.append(transcript)

        # Sample exactly one recording for each distinct transcript.
        rows = df[df['sentence'] == transcript]
        row = rows.iloc[np.random.randint(len(rows))]
        intent_list.append(row['intent'])
        # Remember the sampled row by the value of its first column so it can be
        # told apart from the unsampled rows (presumably a row id - verify
        # against the CSV). ``.iloc`` makes the positional access explicit;
        # plain ``row[0]`` on a label-indexed Series is deprecated in pandas.
        training_idxs.add(row.iloc[0])

        # Load and augment the audio file.
        wav = wav_path + row['recording_path']
        x, _ = sf.read(wav)
        x_aug = torch.tensor(audio_augment(x), dtype=torch.float, device=device)

        # ----------------- kmeans cluster -----------------
        feature = model.pretrained_model.compute_cnn_features(x_aug)
        cluster_id, cluster_center = kmeans(
            X=feature.reshape(-1, feature.shape[-1]), num_clusters=NUM_CLUSTERS,
            distance=dist, tol=tol, device=device)
        # Keep the assignments of the first item along dim 0 of ``feature`` and
        # collapse runs of identical consecutive cluster ids into one entry.
        first_item_ids = cluster_id.view(feature.shape[0], -1)[0]
        collapsed_ids = torch.unique_consecutive(first_item_ids)
        cluster_ids.append(collapsed_ids.to(device=device, dtype=torch.long))
        cluster_centers.append(cluster_center)

        # ----------------- phoneme ctc -------------------
        phoneme_seq = _transcript_to_phonemes(transcript)
        # A trailing digit on a phoneme (stress marker in CMUdict) is dropped
        # before the id lookup.
        phoneme_label = torch.tensor(
            [phoneme2id[ph[:-1]] if ph[-1].isdigit() else phoneme2id[ph] for ph in phoneme_seq],
            dtype=torch.long, device=device)
        phoneme_list.append(phoneme_label)

        # The same label sequence is used for every item along dim 0 of x_aug.
        batch_size = x_aug.shape[0]
        phoneme_label = phoneme_label.repeat(batch_size, 1)
        phoneme_pred = model.pretrained_model.compute_phonemes(x_aug)
        # dim 0 of the prediction is used as the (shared) input length.
        pred_lengths = torch.full(size=(batch_size,), fill_value=phoneme_pred.shape[0], dtype=torch.long)
        label_lengths = torch.full(size=(batch_size,), fill_value=phoneme_label.shape[-1], dtype=torch.long)
        loss = ctc_loss(phoneme_pred, phoneme_label, pred_lengths, label_lengths)

        if torch.isnan(loss).any():
            num_nan_train = num_nan_train + 1
            print('nan training on speaker: %s' % user_id)
            optim.zero_grad()
            # Skip the update: back-propagating a NaN loss would write NaN
            # into the model weights and poison every later step.
            continue

        loss.backward()
        optim.step()

    if num_nan_train:
        print('nan in train happens %d times' % num_nan_train)
    print('train %d test %d' % (len(training_idxs), len(df) - len(training_idxs)))

    # remove unnecessary layers
    # del model.pretrained_model.word_layers
    # del model.pretrained_model.word_linear
    # del model.intent_layers
    filename = f'slurp_curated_headset_base_multicache_{user_id}_audio_bucket_{bucket}'
    # Make sure the output folder exists before writing into it.
    os.makedirs(folder_path, exist_ok=True)
    file_path = os.path.join(folder_path, filename + '.pth')
    torch.save(model.state_dict(), file_path)

    metadata = {
        'df': df,
        'bucket': bucket,
        'speakerId': user_id,
        'transcript_list': transcript_list,
        'phoneme_list': phoneme_list,
        'intent_list': intent_list,
        'training_idxs': training_idxs,
        'cluster_ids': cluster_ids,
        'cluster_centers': cluster_centers,
    }
    with gzip.open(os.path.join(folder_path, filename + '.pkl.gz'), 'wb') as f:
        pickle.dump(metadata, f)
# Alternative source CSVs used by other experiment variants:
# slurp_df = pd.read_csv(os.path.join(pwd, 'slurp_mini_FE_MO_ME_FO_UNK.csv'))
# slurp_df = pd.read_csv(os.path.join(pwd, 'SLURP/csv/slurp_new_df.csv'))
# slurp_df = pd.read_csv(os.path.join(pwd, 'SLURP/csv/slurp_without_headset.csv'))
slurp_df = pd.read_csv(os.path.join(pwd, 'SLURP/csv/slurp_headset.csv'))
speakers = np.unique(slurp_df['user_id'])
# speakers = ['MO-433', 'UNK-326', 'FO-232']
# speakers = ['FE-141']
cumulative_sample, cumulative_correct, cumulative_hit, cumulative_hit_correct = 0, 0, 0, 0

# The pretrained checkpoint is the same for every speaker and bucket, so it is
# read from disk once and copied into each fresh model by load_state_dict().
# pretrained_file = "slurp-pretrained.pth"
# pretrained_path = os.path.join(pwd + "/models/SLURP/", pretrained_file)
pretrained_path = "experiments/no_unfreezing/training/model_state.pth"
pretrained_state = torch.load(pretrained_path, map_location=device)

# NOTE: the loop variable must stay a module-level name called ``user_id``;
# train() reads it as a global for logging and for the output file name.
for user_id in tqdm(speakers, total=len(speakers)):
    print('training for speaker %s' % user_id)
    df = slurp_df[slurp_df['user_id'] == user_id]

    # Measure every recording once, then split the speaker's rows into three
    # duration buckets: [0, 2.7], (2.7, 4], and everything else (presumably
    # seconds - confirm against utils.get_audio_duration).
    durations = np.array(
        [get_audio_duration(wav_path + rel_path) for rel_path in df['recording_path']],
        dtype=float,
    )
    short_mask = (durations >= 0) & (durations <= 2.7)
    medium_mask = (durations > 2.7) & (durations <= 4)
    rest_mask = ~(short_mask | medium_mask)
    df1, df2, df3 = df[short_mask], df[medium_mask], df[rest_mask]
    print('%d' % (len(df1) + len(df2) + len(df3)))

    # One independent model per bucket, each starting from the pretrained weights.
    for bucket, bucket_df in enumerate((df1, df2, df3), start=1):
        # NOTE(review): the model is put in eval mode and then optimised -
        # confirm this is intentional.
        model = models.Model(config).eval()
        optim = torch.optim.Adam(model.parameters(), lr=1e-3)
        model.load_state_dict(pretrained_state)  # load trained model
        ctc_loss = torch.nn.CTCLoss()
        train(model, optim, bucket_df, ctc_loss, bucket=bucket)