# Untitled
#
# mail@pastecode.io avatar
# unknown
# plain_text
# a year ago
# 7.3 kB
# 3
# Indexable
# Never
########################################################################################
#
# This file implements the components to construct a datapipe for the tiny-voxceleb
# dataset.
#
# Author(s): Nik Vaessen
########################################################################################

import collections
import functools
import json
import pathlib

from typing import Tuple, Dict, List

import torch as t
import torch.utils.data
import torchaudio

from torch.utils.data.datapipes.utils.common import StreamWrapper
from torchdata.datapipes.iter import (
    FileLister,
    Shuffler,
    Header,
    ShardingFilter,
    FileOpener,
    Mapper,
    TarArchiveLoader,
    WebDataset,
    IterDataPipe,
    Batcher,
)


########################################################################################
# helper methods for decoding binary streams from files to useful python objects


def decode_wav(value: StreamWrapper) -> t.Tensor:
    """Decode a wav file stream into a waveform tensor.

    The stream must be a ``StreamWrapper`` and the audio must be sampled at
    16 kHz (both enforced with assertions). Singleton dimensions are squeezed
    away, so mono audio comes back as a 1-D tensor.
    """
    assert isinstance(value, StreamWrapper)

    waveform, sr = torchaudio.load(value)
    assert sr == 16_000

    # torchaudio returns (channels, time); squeezing drops the channel axis
    # for single-channel audio so downstream code sees shape [time]
    return t.squeeze(waveform)


def decode_json(value: StreamWrapper) -> Dict:
    """Parse the JSON document contained in ``value`` and return it."""
    assert isinstance(value, StreamWrapper)

    parsed = json.load(value)
    return parsed


def decode(element: Tuple[str, StreamWrapper]):
    """Decode one ``(filename, stream)`` pair coming out of a tar archive.

    ``.wav`` entries are decoded to a waveform tensor and ``.json`` entries to
    a dict; any other entry is returned with its stream untouched. The
    filename is always passed through unchanged.
    """
    assert isinstance(element, tuple) and len(element) == 2
    name, payload = element

    assert isinstance(name, str)
    assert isinstance(payload, StreamWrapper)

    # a filename can only carry one of these suffixes, so the first match wins
    if name.endswith(".wav"):
        return name, decode_wav(payload)
    if name.endswith(".json"):
        return name, decode_json(payload)

    return name, payload


########################################################################################
# default pipeline loading data from tar files into a tuple (sample_id, x, y)

# One data point as (sample_id, x, y). After batching (_batch_samples) the same
# type holds the batch: a list of ids, collated x, and collated y (or None).
Sample = collections.namedtuple("Sample", "sample_id x y")


def construct_sample_datapipe(
    shard_folder: pathlib.Path,
    num_workers: int,
    buffer_size: int = 0,
    shuffle_shards_on_epoch: bool = False,
) -> IterDataPipe[Sample]:
    """Build a datapipe streaming ``Sample(sample_id, x, y)`` from tar shards.

    Args:
        shard_folder: folder containing the ``shard-*.tar`` files (a ``str``
            is accepted as well). There is no default; it must be given.
        num_workers: number of dataloader workers the stream is split over.
            With ``num_workers > 0`` the shards are truncated to the largest
            multiple of ``num_workers`` and then sharded, so every worker
            reads exactly the same number of shards.
        buffer_size: if > 0, size of the sample-level shuffle buffer.
        shuffle_shards_on_epoch: shuffle the order of the shards.

    Returns:
        datapipe yielding ``Sample`` tuples.

    Raises:
        ValueError: if no shards are found, or if there are more workers
            than shards.
    """
    shard_folder = pathlib.Path(shard_folder)

    # list all shards; sorted because glob order is filesystem-dependent and
    # the truncation below should drop the same shards on every machine
    shard_list = sorted(str(f) for f in shard_folder.glob("shard-*.tar"))

    if len(shard_list) == 0:
        raise ValueError(f"unable to find any shards in {shard_folder}")

    # stream of strings representing each shard
    dp = FileLister(shard_list)

    # shuffle the stream so order of shards in epoch differs
    if shuffle_shards_on_epoch:
        dp = Shuffler(dp, buffer_size=len(shard_list))

    # make sure each worker receives the same number of shards
    if num_workers > 0:
        if len(shard_list) < num_workers:
            raise ValueError(
                f"{num_workers=} cannot be larger than {len(shard_list)=}"
            )

        # Header limits the *total* stream (it runs before sharding), so keep
        # the largest multiple of num_workers shards; ShardingFilter then
        # hands each worker exactly shards_per_worker of them
        shards_per_worker = len(shard_list) // num_workers
        dp = Header(dp, limit=shards_per_worker * num_workers)

        # each worker only sees 1/n elements
        dp = ShardingFilter(dp)

    # map strings of paths to file handles
    dp = FileOpener(dp, mode="b")

    # expand each file handle to a stream of all files in the tar
    dp = TarArchiveLoader(dp, mode="r")

    # decode each file in the tar to the expected python dataformat
    dp = Mapper(dp, decode)

    # each file in the tar is expected to have the format `{key}.{ext}`;
    # this groups all files with the same key into one dictionary
    dp = WebDataset(dp)

    # transform the dictionaries into tuple (sample_id, x, y)
    dp = Mapper(dp, map_dict_to_tuple)

    # buffer tuples to increase variability
    if buffer_size > 0:
        dp = Shuffler(dp, buffer_size=buffer_size)

    return dp


def map_dict_to_tuple(x: Dict) -> Sample:
    """Convert one grouped WebDataset dict into a ``Sample``.

    Reads ``sample_id`` and ``class_idx`` from the decoded ``.json`` entry and
    the waveform from the ``.wav`` entry. A ``class_idx`` of ``None`` yields
    ``y=None``; otherwise ``y`` is an int64 scalar tensor.
    """
    meta = x[".json"]

    sample_id = meta["sample_id"]
    waveform = x[".wav"]
    label = meta["class_idx"]

    gt = None if label is None else t.tensor(label, dtype=t.int64)

    return Sample(sample_id, waveform, gt)


########################################################################################
# useful transformation on a stream of sample objects


def _chunk_sample(sample: Sample, num_frames: int):
    """Crop a random contiguous window of ``num_frames`` frames from ``x``.

    Every start offset in ``[0, len(x) - num_frames]`` is possible, so a
    sample of exactly ``num_frames`` frames is returned whole.

    Args:
        sample: sample whose ``x`` is a 1-D waveform (before e.g. MFCC).
        num_frames: length of the window to extract.

    Returns:
        a new Sample whose ``x`` has exactly ``num_frames`` frames; the id and
        label are passed through.

    Raises:
        ValueError: if the waveform is shorter than ``num_frames``.
    """
    sample_id, x, y = sample

    assert len(x.shape) == 1  # before e.g. mfcc transformation

    sample_length = x.shape[0]
    if sample_length < num_frames:
        raise ValueError(
            f"sample {sample_id} has {sample_length} frames, "
            f"fewer than the requested {num_frames=}"
        )

    # torch.randint's `high` is exclusive and the last valid start offset is
    # sample_length - num_frames, hence the +1
    start_idx = int(t.randint(low=0, high=sample_length - num_frames + 1, size=()))
    end_idx = start_idx + num_frames

    return Sample(sample_id, x[start_idx:end_idx], y)

"""
Extract a fixed length segment from each audio sample in the pipeline - num_frames: length of the segment
"""
def pipe_chunk_sample(
    dp: IterDataPipe[Sample], num_frames: int
) -> IterDataPipe[Sample]:
    return Mapper(dp, functools.partial(_chunk_sample, num_frames=num_frames))


def _mfcc_sample(sample: Sample, mfcc: torchaudio.transforms.MFCC):
    """Replace a sample's waveform with its MFCC features.

    ``sample_id`` and ``y`` are carried over unchanged.
    """
    sample_id, waveform, label = sample

    # a 1-D waveform becomes (n_mfcc, T); T is the number of STFT frames,
    # roughly len(waveform) / hop_length (exact value depends on the transform's
    # settings — see torchaudio's MFCC docs)
    features = mfcc(waveform)

    return Sample(sample_id, features, label)


def pipe_mfcc(dp: IterDataPipe[Sample], n_mfcc: int = 40) -> IterDataPipe[Sample]:
    """Map every sample's waveform to MFCC features.

    Args:
        dp: datapipe of samples with a waveform ``x``.
        n_mfcc: number of MFCC coefficients; defaults to 40, the same default
            torchaudio's ``MFCC`` transform uses.

    Returns:
        datapipe yielding samples whose ``x`` is the MFCC feature matrix.
    """
    # one transform instance is built here and shared by every sample; it uses
    # torchaudio's default sample_rate (16 kHz per its docs), which matches the
    # 16 kHz assertion in decode_wav
    mfcc = torchaudio.transforms.MFCC(n_mfcc=n_mfcc)

    return Mapper(dp, functools.partial(_mfcc_sample, mfcc=mfcc))

"""
Adding noise to the audio.
"""
def _add_noise_sample(sample: Sample, noise_factor: float):
    sample_id, x, y = sample
    
    noise = t.randn_like(x) * noise_factor
    x = x + noise
    
    return Sample(sample_id, x, y)

"""
Function to implement different data agumentation techniques:
- Adding noise
"""
def additional_augmentation(dp: IterDataPipe[Sample], noise_factor: float) -> IterDataPipe[Sample]:
    dp = Mapper(dp, functools.partial(_add_noise_sample, noise_factor=noise_factor))

    return dp


def _batch_samples(samples: List[Sample]):
    """Collate a non-empty list of equally-shaped samples into one Sample.

    ``x`` (and ``y``, unless the first sample's ``y`` is ``None``) are merged
    with ``default_collate``; the ids are gathered into a plain list.
    """
    assert len(samples) > 0

    # every x must share a single shape for the collation to be valid
    assert len({s.x.shape for s in samples}) == 1

    collate = torch.utils.data.default_collate
    batched_x = collate([s.x for s in samples])

    if samples[0].y is None:
        batched_y = None
    else:
        batched_y = collate([s.y for s in samples])

    ids = [s.sample_id for s in samples]

    return Sample(ids, batched_x, batched_y)


def pipe_batch_samples(
    dp: IterDataPipe[Sample], batch_size: int, drop_last: bool = False
) -> IterDataPipe[Sample]:
    """Group consecutive samples into batches, each collated by ``_batch_samples``.

    With ``drop_last=True`` a final incomplete batch is discarded.
    """
    batched = Batcher(
        dp,
        batch_size,
        drop_last=drop_last,
        wrapper_class=_batch_samples,
    )
    return batched


########################################################################################
# useful for debugging a datapipe


def _print_sample(dp):
    for sample in dp:
        sample_id, x, y = sample
        print(f"{sample_id=}\n")

        print(f"{x.shape=}")
        print(f"{x.dtype=}\n")

        print(y)
        print(f"{y.shape=}")
        print(f"{y.dtype=}\n")
        break


def _debug(
    shard_path: pathlib.Path = pathlib.Path(
        "/home/nvaessen/phd/repo/tiny-voxceleb-skeleton/data/shards/train"
    ),
):
    """Smoke-test the pipeline by printing the first sample after each stage.

    Args:
        shard_path: folder containing ``shard-*.tar`` files. The default is the
            original author's local training-shard folder; pass your own path
            (a ``str`` works too).
    """
    shard_path = pathlib.Path(shard_path)

    print("### construct_sample_datapipe ###")
    dp = construct_sample_datapipe(shard_path, num_workers=0)
    _print_sample(dp)

    print("### pipe_chunk_sample ###")
    dp = pipe_chunk_sample(dp, 16_000 * 3)  # 3 seconds at 16 kHz
    _print_sample(dp)

    print("### pipe_mfcc ###")
    # 40 matches torchaudio's default number of MFCC coefficients
    dp = pipe_mfcc(dp, n_mfcc=40)
    _print_sample(dp)

    print("### pipe_batch_samples ###")
    dp = pipe_batch_samples(dp, 8)
    _print_sample(dp)


# run the datapipe smoke test (_debug) when this module is executed as a script
if __name__ == "__main__":
    _debug()