Source code for sbss.nfca.encoders.dilcnv_encoder

# Copyright (C) 2025 National Institute of Advanced Industrial Science and Technology (AIST)
# SPDX-License-Identifier: MIT

from einops import rearrange

import torch  # noqa
from torch import nn
from torch.distributions import Normal
from torch.nn import functional as fn

from einops.layers.torch import Rearrange


class DepSepConv1d(nn.Sequential):
    """Residual depthwise-separable 1-D convolution layer.

    The residual branch is, in order:

    1. pointwise conv ``io_ch -> mid_ch`` -> PReLU -> GroupNorm,
    2. dilated depthwise conv ``mid_ch -> mid_ch`` (``groups=mid_ch``) -> PReLU -> GroupNorm,
    3. pointwise conv ``mid_ch -> io_ch``.

    ``forward`` returns ``x + branch(x)``. Each GroupNorm uses one group per channel.

    Args:
        io_ch (int): Number of input and output channels.
        mid_ch (int): Number of hidden channels inside the residual branch.
        ksize (int): Kernel size of the depthwise convolution.
        dilation (int): Dilation of the depthwise convolution.
    """

    def __init__(self, io_ch, mid_ch, ksize, dilation):
        super().__init__(
            nn.Conv1d(io_ch, mid_ch, 1),
            nn.PReLU(mid_ch),
            nn.GroupNorm(mid_ch, mid_ch),
            #
            # padding="same" keeps the time length unchanged so the residual sum in
            # `forward` is always well defined. For odd `ksize` this equals the
            # symmetric padding `(ksize - 1) // 2 * dilation`. For even `ksize`,
            # PyTorch pads asymmetrically instead of producing a shorter output
            # that cannot be added to `x`.
            nn.Conv1d(mid_ch, mid_ch, ksize, padding="same", dilation=dilation, groups=mid_ch),
            nn.PReLU(mid_ch),
            nn.GroupNorm(mid_ch, mid_ch),
            #
            nn.Conv1d(mid_ch, io_ch, 1),
        )

    def forward(self, x):
        """Return ``x + branch(x)``; the output has the same shape as ``x``.

        ``x`` is expected to be ``[B, io_ch, T]`` (io_ch is the first conv's input width).
        """
        return x + super().forward(x)


class DilConvBlock1d(nn.Sequential):
    """Stack of ``n_layers`` residual ``DepSepConv1d`` layers with doubling dilation.

    Layer ``i`` (0-based) is ``DepSepConv1d(io_ch, mid_ch, ksize, 2**i)``, so the
    dilation sequence is 1, 2, 4, ..., ``2**(n_layers - 1)``.

    Args:
        io_ch (int): Input/output channels of every layer.
        mid_ch (int): Hidden channels inside each layer's residual branch.
        ksize (int): Depthwise kernel size shared by all layers.
        n_layers (int): Number of stacked layers.
    """

    def __init__(self, io_ch, mid_ch, ksize, n_layers):
        dilations = [1 << depth for depth in range(n_layers)]
        super().__init__(*(DepSepConv1d(io_ch, mid_ch, ksize, d) for d in dilations))


class DilcnvEncoder(nn.Module):
    """Encoder module using dilated depthwise separable convolutions for multichannel spectrograms.

    This encoder extracts feature representations from multichannel complex
    spectrograms. It computes the log-power of the reference channel (channel 0)
    and batch-normalizes it per frequency bin. It then concatenates the
    unit-magnitude inter-channel ratios (phase differences) of the other channels
    relative to the reference, split into real and imaginary parts. The combined
    features are processed by a stack of dilated convolutional blocks to produce
    latent variables, which can optionally be returned as Gaussian distributions.

    Args:
        n_fft (int): FFT size used to determine the number of frequency bins.
        n_mic (int): Number of microphone channels in the input.
        n_src (int): Number of sources to estimate in the latent space.
        dim_latent (int): Dimensionality of the latent variable.
        io_ch (int, optional): Number of input/output channels for convolution blocks. Defaults to 256.
        mid_ch (int, optional): Number of intermediate channels in convolution blocks. Defaults to 512.
        n_blocks (int, optional): Number of stacked convolutional blocks. Defaults to 4.
        n_layers (int, optional): Number of dilated convolution layers per block. Defaults to 8.
        ksize (int, optional): Kernel size for dilated convolutions. Defaults to 3.
    """

    def __init__(
        self,
        n_fft: int,
        n_mic: int,
        n_src: int,
        dim_latent: int,
        io_ch: int = 256,
        mid_ch: int = 512,
        n_blocks: int = 4,
        n_layers: int = 8,
        ksize: int = 3,
    ):
        super().__init__()

        n_stft = n_fft // 2 + 1

        # Normalizes the reference-channel log-power, one statistic per frequency bin.
        self.bn0 = nn.BatchNorm1d(n_stft)

        # Input width: n_stft log-power features plus (n_mic - 1) * n_stft complex
        # ratios split into real/imag parts, i.e. (2 * n_mic - 1) * n_stft channels.
        self.cnv = nn.Sequential(
            nn.Conv1d((2 * n_mic - 1) * n_stft, io_ch, 1),
            *[DilConvBlock1d(io_ch, mid_ch, ksize, n_layers) for _ in range(n_blocks)],
        )

        # Projects to 2 * D * N channels and moves the "2" (mean / pre-scale) to the
        # leading axis so `forward` can unpack it directly.
        self.cnv_z = nn.Sequential(
            nn.Conv1d(io_ch, 2 * dim_latent * n_src, 1),
            Rearrange("b (c d n) t -> c b d n t", c=2, d=dim_latent, n=n_src),
        )

    def forward(self, x: torch.Tensor, distribution: bool = False) -> "Normal | torch.Tensor":
        """Encode a multichannel spectrogram into latent variables.

        Args:
            x (torch.Tensor): Complex multichannel spectrogram of shape ``[B, F, M, T]``,
                with ``F == n_fft // 2 + 1`` and ``M == n_mic``. Channel 0 is the reference.
            distribution (bool, optional): If True, return a ``Normal`` distribution;
                otherwise return only its mean. Defaults to False.

        Returns:
            torch.distributions.Normal | torch.Tensor: With ``distribution=True``, a
            ``Normal`` whose mean and scale are ``[B, D, N, T]`` (scale is
            ``softplus(.) + 1e-6``). Otherwise, the mean tensor ``[B, D, N, T]``.

        Raises:
            ValueError: If ``x`` is not a 4-D tensor.
        """
        if x.dim() != 4:
            raise ValueError(f"expected a 4-D tensor [B, F, M, T], got shape {tuple(x.shape)}")

        # Reference-channel log-power; the 1e-6 floor keeps log() finite.
        logx = self.bn0(x[..., 0, :].abs().square().clip(1e-6).log())  # [B, F, T]

        # Ratio of each non-reference channel to the reference; the small offset
        # guards against dividing by an exactly-zero reference bin.
        ph = x[..., 1:, :] / (x[..., 0, None, :] + 1e-6)  # [B, F, M-1, T]
        # Normalize to (approximately) unit magnitude. Kept out-of-place on purpose:
        # torch.abs saves `ph` for its backward pass, so an in-place `/=` would
        # break autograd whenever `x` requires grad.
        ph = ph / ph.abs().clip(1e-6)
        ph = rearrange(torch.view_as_real(ph), "b f m t c -> b (f m c) t")  # [B, 2F(M-1), T]

        h = torch.concat([logx, ph], dim=1)  # [B, (2M-1)F, T]

        # main convolutions
        h = self.cnv(h)  # [B, io_ch, T]

        z_mu, z_sig_ = self.cnv_z(h)  # cnv_z yields [2, B, D, N, T]; each part is [B, D, N, T]
        if not distribution:
            return z_mu
        return Normal(z_mu, fn.softplus(z_sig_) + 1e-6)