Source code for sbss.common.datasets.wav_hdf5_dataset

# Copyright (C) 2025 National Institute of Advanced Industrial Science and Technology (AIST)
# SPDX-License-Identifier: MIT

from pathlib import Path

import numpy as np

import torch

from aiaccel.torch.datasets import CachedDataset, HDF5Dataset


[docs] class WavHDF5Dataset(torch.utils.data.Dataset): """Dataset class for loading and preprocessing multi-channel audio data stored in HDF5 format. This dataset wraps an ``HDF5Dataset`` using a caching layer for faster access. It supports random microphone permutation for data augmentation and optional random cropping of the waveform based on the specified duration and sampling rate. Args: dataset_path (Path | str): Path to the HDF5 dataset directory or file. duration (int | None, optional): Target duration of audio clips in seconds. If None, the full audio is used. sr (int | None, optional): Sampling rate used to calculate the crop length when ``duration`` is specified. randperm_mic (bool, optional): Whether to randomly permute microphone channels. Defaults to True. grp_list (Path | str | list[str] | None, optional): List or path specifying group names to load from the HDF5 dataset. If None, all groups are used. Returns: torch.utils.data.Dataset: A PyTorch dataset yielding preprocessed multi-channel waveforms. """
[docs] def __init__( self, dataset_path: Path | str, duration: int | None = None, sr: int | None = None, randperm_mic: bool = True, grp_list: Path | str | list[str] | None = None, ) -> None: super().__init__() self._dataset = CachedDataset(HDF5Dataset(dataset_path, grp_list)) self.duration = duration self.sr = sr self.randperm_mic = randperm_mic
def __len__(self) -> int: return len(self._dataset) def __getitem__(self, index: int): item = self._dataset[index] wav = item["wav"] if self.duration is not None: duration = self.sr * self.duration t_start = np.random.randint(0, wav.shape[1] - duration + 1) t_end = t_start + duration wav = wav[:, t_start:t_end] if self.randperm_mic: wav = wav[torch.randperm(wav.shape[0])] return wav