########################################################################################
#
# This file implements the components to construct a datapipe for the tiny-voxceleb
# dataset.
#
# Author(s): Nik Vaessen
########################################################################################
import collections
import functools
import json
import pathlib
from typing import Tuple, Dict, List
import torch as t
import torch.utils.data
import torchaudio
from torch.utils.data.datapipes.utils.common import StreamWrapper
from torchdata.datapipes.iter import (
FileLister,
Shuffler,
Header,
ShardingFilter,
FileOpener,
Mapper,
TarArchiveLoader,
WebDataset,
IterDataPipe,
Batcher,
)
########################################################################################
# helper methods for decoding binary streams from files to useful python objects
def decode_wav(value: StreamWrapper) -> t.Tensor:
assert isinstance(value, StreamWrapper)
value, sample_rate = torchaudio.load(value)
assert sample_rate == 16_000
# make sure that audio has 1 dimension
    value = t.squeeze(value)
return value
def decode_json(value: StreamWrapper) -> Dict:
assert isinstance(value, StreamWrapper)
return json.load(value)
def decode(element: Tuple[str, StreamWrapper]):
assert isinstance(element, tuple) and len(element) == 2
key, value = element
assert isinstance(key, str)
assert isinstance(value, StreamWrapper)
if key.endswith(".wav"):
value = decode_wav(value)
if key.endswith(".json"):
value = decode_json(value)
return key, value
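# A small, hypothetical sanity check for the decoders above: stream one shard
# and confirm the first ".wav" entry decodes to a 1-dimensional tensor. The
# shard path is a placeholder assumption; point it at a real shard to run it.
def _example_decode(shard_path: str = "data/shards/train/shard-000000.tar"):
    dp = FileOpener([shard_path], mode="b")
    dp = TarArchiveLoader(dp, mode="r")
    dp = Mapper(dp, decode)
    for key, value in dp:
        if key.endswith(".wav"):
            assert value.ndim == 1
            return key, value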
########################################################################################
# default pipeline loading data from tar files into a tuple (sample_id, x, y)
Sample = collections.namedtuple("Sample", ["sample_id", "x", "y"])
def construct_sample_datapipe(
    shard_folder: pathlib.Path,
num_workers: int,
buffer_size: int = 0,
shuffle_shards_on_epoch: bool = False,
) -> IterDataPipe[Sample]:
# list all shards
shard_list = [str(f) for f in shard_folder.glob("shard-*.tar")]
if len(shard_list) == 0:
raise ValueError(f"unable to find any shards in {shard_folder}")
# stream of strings representing each shard
dp = FileLister(shard_list)
# shuffle the stream so order of shards in epoch differs
if shuffle_shards_on_epoch:
dp = Shuffler(dp, buffer_size=len(shard_list))
# make sure each worker receives the same number of shards
if num_workers > 0:
        if len(shard_list) < num_workers:
            raise ValueError(
                f"{num_workers=} cannot be larger than {len(shard_list)=}"
            )
        # truncate to the largest multiple of num_workers so that every
        # worker receives the same number of shards
        dp = Header(dp, limit=(len(shard_list) // num_workers) * num_workers)
# each worker only sees 1/n elements
dp = ShardingFilter(dp)
# map strings of paths to file handles
dp = FileOpener(dp, mode="b")
# expand each file handle to a stream of all files in the tar
dp = TarArchiveLoader(dp, mode="r")
# decode each file in the tar to the expected python dataformat
dp = Mapper(dp, decode)
    # each file in the tar is expected to have the format `{key}.{ext}`
# this groups all files with the same key into one dictionary
dp = WebDataset(dp)
# transform the dictionaries into tuple (sample_id, x, y)
dp = Mapper(dp, map_dict_to_tuple)
# buffer tuples to increase variability
if buffer_size > 0:
dp = Shuffler(dp, buffer_size=buffer_size)
return dp
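# Minimal usage sketch: wrap the datapipe in a DataLoader with batch_size=None,
# since batching is done explicitly by pipe_batch_samples below; the shard
# folder path here is an assumption for illustration.
def _example_dataloader() -> torch.utils.data.DataLoader:
    dp = construct_sample_datapipe(
        pathlib.Path("data/shards/train"), num_workers=4, buffer_size=100
    )
    return torch.utils.data.DataLoader(dp, batch_size=None, num_workers=4)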
def map_dict_to_tuple(x: Dict) -> Sample:
sample_id = x[".json"]["sample_id"]
wav = x[".wav"]
class_idx = x[".json"]["class_idx"]
    if class_idx is None:
        # test samples come without a label
        gt = None
    else:
        gt = t.tensor(class_idx, dtype=t.int64)
return Sample(sample_id, wav, gt)
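# For reference: WebDataset groups the tar entries `{key}.wav` and `{key}.json`
# into a dict {"__key__": key, ".wav": <Tensor>, ".json": <dict>}, which is the
# input map_dict_to_tuple receives above.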
########################################################################################
# useful transformation on a stream of sample objects
def _chunk_sample(sample: Sample, num_frames: int):
    sample_id, x, y = sample
    assert len(x.shape) == 1  # before e.g. mfcc transformation
    sample_length = x.shape[0]
    assert sample_length >= num_frames, "sample is shorter than requested chunk"
    # randint's upper bound is exclusive, so +1 keeps the last valid start
    # index (sample_length - num_frames) reachable
    start_idx = t.randint(low=0, high=sample_length - num_frames + 1, size=())
    end_idx = start_idx + num_frames
    x = x[start_idx:end_idx]
    return Sample(sample_id, x, y)
"""
Extract a fixed length segment from each audio sample in the pipeline - num_frames: length of the segment
"""
def pipe_chunk_sample(
dp: IterDataPipe[Sample], num_frames: int
) -> IterDataPipe[Sample]:
return Mapper(dp, functools.partial(_chunk_sample, num_frames=num_frames))
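# Example: cut every sample to 3 seconds of 16 kHz audio, as done in _debug()
# at the bottom of this file:
#
#   dp = pipe_chunk_sample(dp, num_frames=16_000 * 3)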
def _mfcc_sample(sample: Sample, mfcc: torchaudio.transforms.MFCC):
sample_id, x, y = sample
# we go from shape [num_frames] to [num_mel_coeff, num_frames/window_size]
x = mfcc(x)
return Sample(sample_id, x, y)
def pipe_mfcc(dp: IterDataPipe[Sample], n_mfcc: int) -> IterDataPipe[Sample]:
mfcc = torchaudio.transforms.MFCC(n_mfcc=n_mfcc)
return Mapper(dp, functools.partial(_mfcc_sample, mfcc=mfcc))
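# With torchaudio's MFCC defaults (sample_rate=16_000, n_fft=400,
# hop_length=200), a 3-second chunk of 48_000 frames becomes a tensor of
# shape [n_mfcc, 241]; pass melkwargs to MFCC if other settings are needed.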
"""
Adding noise to the audio.
"""
def _add_noise_sample(sample: Sample, noise_factor: float):
sample_id, x, y = sample
noise = t.randn_like(x) * noise_factor
x = x + noise
return Sample(sample_id, x, y)
"""
Function to implement different data agumentation techniques:
- Adding noise
"""
def additional_augmentation(dp: IterDataPipe[Sample], noise_factor: float) -> IterDataPipe[Sample]:
dp = Mapper(dp, functools.partial(_add_noise_sample, noise_factor=noise_factor))
return dp
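# Example: light gaussian noise on waveforms roughly in [-1, 1] as loaded by
# torchaudio; noise_factor=0.005 is a hypothetical starting point, not a
# tuned value:
#
#   dp = additional_augmentation(dp, noise_factor=0.005)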
def _batch_samples(samples: List[Sample]):
assert len(samples) > 0
shapes = {s.x.shape for s in samples}
assert len(shapes) == 1 # all samples have same shape
x = torch.utils.data.default_collate([s.x for s in samples])
if samples[0].y is not None:
y = torch.utils.data.default_collate([s.y for s in samples])
else:
y = None
sample_id = [s.sample_id for s in samples]
return Sample(sample_id, x, y)
def pipe_batch_samples(
dp: IterDataPipe[Sample], batch_size: int, drop_last: bool = False
) -> IterDataPipe[Sample]:
return Batcher(dp, batch_size, drop_last=drop_last, wrapper_class=_batch_samples)
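# After batching, x has shape [batch_size, *sample_shape], sample_id is a
# plain python list of the individual ids, and y is None for unlabeled (e.g.
# test) shards, which is why _batch_samples only collates existing labels.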
########################################################################################
# useful for debugging a datapipe
def _print_sample(dp):
    for sample in dp:
        sample_id, x, y = sample
        print(f"{sample_id=}\n")
        print(f"{x.shape=}")
        print(f"{x.dtype=}\n")
        if y is None:
            print("y=None (unlabeled sample)\n")
        else:
            print(y)
            print(f"{y.shape=}")
            print(f"{y.dtype=}\n")
        break
def _debug():
shard_path = pathlib.Path(
"/home/nvaessen/phd/repo/tiny-voxceleb-skeleton/data/shards/train"
)
print("### construct_sample_datapipe ###")
dp = construct_sample_datapipe(shard_path, num_workers=0)
_print_sample(dp)
print("### pipe_chunk_sample ###")
dp = pipe_chunk_sample(dp, 16_000 * 3) # 3 seconds
_print_sample(dp)
print("### pipe_mfcc ###")
    dp = pipe_mfcc(dp, n_mfcc=40)  # 40 coefficients, torchaudio's default
_print_sample(dp)
print("### pipe_batch_samples ###")
dp = pipe_batch_samples(dp, 8)
_print_sample(dp)
if __name__ == "__main__":
_debug()