Source code for sbss.common.diagonalizers.iss_diagonalizer

# Copyright (C) 2025 National Institute of Advanced Industrial Science and Technology (AIST)
# SPDX-License-Identifier: MIT

import torch
from torch import nn


class ISSDiagonalizer(nn.Module):
    """Iterative Source Steering (ISS) algorithm for diagonalizing spatial covariance matrices.

    This module performs iterative updates of the demixing matrix ``Q`` based on
    the input mixture ``x`` and inverse power spectrum density ``r``. It can be
    used as part of spatial source separation algorithms such as IVA and
    FastMNMF.

    Args:
        n_iter (int, optional): Number of ISS iterations to perform.
            Defaults to 1.
        eps (float, optional): Regularization constant added to covariance
            diagonals to prevent numerical instability. Defaults to 1e-6.
        eps2 (float, optional): Small value used to clip denominator terms
            during normalization to avoid division by zero. Defaults to 1e-6.
        eps3 (float, optional): Minimum clipping value for the inverse power
            estimates in ``r`` to stabilize computations. Defaults to 1e-3.
        norm_q (bool, optional): Whether to normalize the demixing matrix
            ``Q`` at the end of the iteration. Defaults to False.

    Returns:
        nn.Module: A PyTorch module that outputs the updated demixing matrix
        ``Q`` and the corresponding source power estimates ``xt`` after
        applying the ISS updates.
    """

    def __init__(
        self,
        n_iter: int = 1,
        eps: float = 1e-6,
        eps2: float = 1e-6,
        eps3: float = 1e-3,
        norm_q: bool = False,
    ):
        super().__init__()

        self.n_iter = n_iter
        self.eps = eps
        self.eps2 = eps2
        self.eps3 = eps3
        self.norm_q = norm_q

    def forward(self, r: torch.Tensor, Q: torch.Tensor, x: torch.Tensor):
        """Run ``n_iter`` ISS sweeps and return the updated ``Q`` and powers.

        Parameters
        ----------
        r : (B, F, M, T) Tensor
            Real-valued weights (clipped from below at ``eps3``).
        Q : (B, F, M, M) Tensor
            Initial demixing matrix. It is not modified in place.
        x : (B, F, M, T) Tensor
            Complex-valued mixture (``.imag`` is taken of ``Q @ x``).

        Returns
        -------
        Q : (B, F, M, M) Tensor
            Updated demixing matrix.
        xt : (B, F, M, T) Tensor
            Squared magnitude of ``Q @ x``.
        """
        _, _, M, T = x.shape

        # Weighted covariances V[k] = 1/T * sum_t r[k, t] x[:, t] x[:, t]^H.
        # The weighting is an explicit broadcast multiply so that promotion of
        # the real weights to the complex dtype does not depend on the
        # contraction order chosen by einsum.
        weighted_x = r.clip(self.eps3)[..., :, None, :] * x[..., None, :, :]
        V = torch.einsum("...kmt,...nt->...kmn", weighted_x, x.conj()) / T

        # Diagonal loading scaled by max(trace, 1). The identity follows V's
        # device and dtype so the module works on CPU as well as any
        # accelerator.
        trV = torch.einsum("...mm->...", V)[..., None, None].real
        eye = torch.eye(M, dtype=V.dtype, device=V.device)
        V = V + trV.maximum(torch.ones_like(trV)).to(V.dtype) * self.eps * eye

        for _ in range(self.n_iter):
            for k in range(M):
                q = Q[..., k, :]

                Vq = torch.einsum("...kmn,...n->...km", V, q.conj())
                # q V_k q^H for every k; clipped because it is a denominator.
                qVq = torch.einsum("...m,...km->...k", q, Vq).real.clip(self.eps2)

                v = torch.einsum("...km,...km->...k", Q, Vq) / qVq.to(x.dtype)
                # The k-th row is only rescaled, so that q_k V_k q_k^H becomes 1.
                v[..., k] = 1 - qVq[..., k] ** -0.5

                # Rank-1 update: subtract v[m] * q from every row m of Q.
                Q = Q - torch.einsum("...m,...n->...mn", v, q)

        Qx = Q @ x
        xt = Qx.real**2 + Qx.imag**2  # torch.abs(Qx) ** 2

        if self.norm_q:
            # One scale per batch item (mean over all non-batch axes).
            # NOTE: the lower clip is applied only when rescaling Q, not xt.
            scale = xt.mean(dim=(1, 2, 3), keepdim=True)
            xt = xt / scale
            Q = Q / scale.clip(1e-6).sqrt().to(x.dtype)

        return Q, xt