Source code for sbss.common.distributions

# Copyright (C) 2025 National Institute of Advanced Industrial Science and Technology (AIST)
# SPDX-License-Identifier: MIT

import math

import torch
from torch.distributions import RelaxedBernoulli, SigmoidTransform
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.relaxed_bernoulli import LogitRelaxedBernoulli


class ApproxBernoulli(RelaxedBernoulli):
    """Bernoulli approximation built on top of a relaxed Bernoulli.

    Reparameterized samples are hardened to ``{0, 1}`` by thresholding the
    relaxed sample at 0.5 in the forward pass. The backward pass uses the
    gradient of the underlying relaxed sample (straight-through estimator).

    NOTE(review): only ``rsample`` is overridden; ``sample()`` comes from the
    parent class and presumably returns un-thresholded relaxed values —
    confirm this is intended.

    Args:
        temperature (float or torch.Tensor): Relaxation temperature parameter.
        probs (torch.Tensor, optional): Probability of success.
        logits (torch.Tensor, optional): Log-odds of success.
        validate_args (bool, optional): Whether to validate the input arguments.
    """

    def __init__(self, temperature, probs=None, logits=None, validate_args=None):
        # Build the logit-space relaxed distribution here so ``validate_args``
        # is forwarded to it too. Then call the TransformedDistribution
        # initializer directly, skipping RelaxedBernoulli.__init__.
        logit_dist = LogitRelaxedBernoulli(
            temperature, probs, logits, validate_args=validate_args
        )
        sigmoid = SigmoidTransform()
        super(RelaxedBernoulli, self).__init__(
            logit_dist, sigmoid, validate_args=validate_args
        )

    def rsample(self, sample_shape=torch.Size()):  # noqa
        """Draw a hard 0/1 sample that carries the relaxed sample's gradient."""
        soft = super().rsample(sample_shape)
        hard = (soft > 0.5).to(soft.dtype)
        # (soft - soft.detach()) is exactly zero in value but has unit
        # gradient w.r.t. ``soft``, so the result equals ``hard`` going
        # forward and backpropagates like ``soft``.
        return hard + (soft - soft.detach())
def cmvlgamma(nu, M):
    """Log of the complex multivariate gamma function.

    Computes, elementwise over ``nu``::

        log CGamma_M(nu) = M (M - 1) / 2 * log(pi) + sum_{k=0}^{M-1} lgamma(nu - k)

    Args:
        nu (float or torch.Tensor): Argument(s) of the function, any shape.
            Non-floating inputs are converted to the default float dtype.
        M (int): Matrix dimension.

    Returns:
        torch.Tensor: Tensor with the same shape, dtype and device as ``nu``.
    """
    nu = torch.as_tensor(nu)
    if not nu.is_floating_point():
        nu = nu.to(torch.get_default_dtype())
    # Build the offsets on nu's own device rather than a hard-coded "cuda".
    # The function then also works on CPU-only machines and never mixes devices.
    offsets = torch.arange(M, dtype=nu.dtype, device=nu.device)
    log_pi_term = M * (M - 1) / 2 * math.log(math.pi)
    return log_pi_term + torch.lgamma(nu[..., None] - offsets).sum(-1)
class ComplexWishart(ExponentialFamily):
    """Complex Wishart distribution.

    Matrix-valued distribution commonly used to model covariance matrices in
    the complex domain. ``log_prob`` evaluates::

        (nu - M) log|Sig| - nu log|Psi| - tr(Psi^{-1} Sig) - log CGamma_M(nu)

    where ``Psi`` is ``covariance_matrix`` and ``CGamma_M`` is the complex
    multivariate gamma function (``cmvlgamma``).

    Args:
        nu (float or torch.Tensor): Degrees of freedom of the distribution.
        covariance_matrix (torch.Tensor): Positive-definite covariance matrix
            of shape ``(..., M, M)``.
    """

    def __init__(self, nu, covariance_matrix):
        self.nu = torch.as_tensor(nu, dtype=torch.float32, device=covariance_matrix.device)
        self.covariance_matrix = covariance_matrix
        self.M = covariance_matrix.shape[-1]
        # Initialize the Distribution base so that ``batch_shape`` and
        # ``event_shape`` are defined. Argument validation is disabled
        # because this class declares no ``arg_constraints``.
        batch_shape = torch.broadcast_shapes(self.nu.shape, covariance_matrix.shape[:-2])
        super().__init__(batch_shape, covariance_matrix.shape[-2:], validate_args=False)

    def log_prob(self, Sig):
        """Log-density evaluated at ``Sig``.

        Args:
            Sig (torch.Tensor): Matrix (or batch of matrices) of shape
                ``(..., M, M)``.

        Returns:
            torch.Tensor: Real log-density, broadcast over the batch
            dimensions of ``Sig``, ``nu`` and ``covariance_matrix``.
        """
        Psi = self.covariance_matrix
        Psii = torch.linalg.inv(Psi)
        # tr(Psi^{-1} Sig): take the real part and clip from below;
        # presumably a numerical safeguard against round-off.
        trPsiiSig = torch.einsum("...mn,...nm->...", Psii, Sig).real.clip(min=1e-6)
        # Out-of-place arithmetic: unlike in-place ``-=``, it lets later
        # terms broadcast to a larger batch shape than the first term.
        logp = (
            (self.nu - self.M) * torch.linalg.slogdet(Sig)[1]
            - self.nu * torch.linalg.slogdet(Psi)[1]
            - trPsiiSig
            - cmvlgamma(self.nu, self.M)
        )
        return logp
class ComplexInverseWishart(ExponentialFamily):
    """Complex inverse Wishart distribution.

    The complex inverse Wishart distribution is the conjugate prior for
    complex covariance matrices, often used in Bayesian signal processing.
    ``log_prob`` evaluates::

        -(nu + M) log|Sig| + nu log|Psi| - tr(Psi Sig^{-1}) - log CGamma_M(nu)

    where ``Psi`` is ``covariance_matrix`` and ``CGamma_M`` is the complex
    multivariate gamma function (``cmvlgamma``).

    Args:
        nu (float or torch.Tensor): Degrees of freedom of the distribution.
        covariance_matrix (torch.Tensor): Positive-definite covariance matrix
            of shape ``(..., M, M)``.
    """

    def __init__(self, nu, covariance_matrix):
        self.nu = torch.as_tensor(nu, dtype=torch.float32, device=covariance_matrix.device)
        self.covariance_matrix = covariance_matrix
        self.M = covariance_matrix.shape[-1]
        # Initialize the Distribution base so that ``batch_shape`` and
        # ``event_shape`` are defined. Argument validation is disabled
        # because this class declares no ``arg_constraints``.
        batch_shape = torch.broadcast_shapes(self.nu.shape, covariance_matrix.shape[:-2])
        super().__init__(batch_shape, covariance_matrix.shape[-2:], validate_args=False)

    def log_prob(self, Sig):
        """Log-density evaluated at ``Sig``.

        Args:
            Sig (torch.Tensor): Invertible matrix (or batch of matrices) of
                shape ``(..., M, M)``.

        Returns:
            torch.Tensor: Real log-density, broadcast over the batch
            dimensions of ``Sig``, ``nu`` and ``covariance_matrix``.
        """
        Psi = self.covariance_matrix
        Sigi = torch.linalg.inv(Sig)
        # tr(Psi Sig^{-1}): take the real part and clip from below;
        # presumably a numerical safeguard against round-off.
        trPsiSigi = torch.einsum("...mn,...nm->...", Psi, Sigi).real.clip(min=1e-6)
        # Out-of-place arithmetic: unlike in-place ``+=``/``-=``, it lets later
        # terms broadcast to a larger batch shape than the first term.
        logp = (
            -(self.nu + self.M) * torch.linalg.slogdet(Sig)[1]
            + self.nu * torch.linalg.slogdet(Psi)[1]
            - trPsiSigi
            - cmvlgamma(self.nu, self.M)
        )
        return logp
class FastComplexWishart(ComplexWishart):
    """Unnormalized complex Wishart distribution.

    ``log_prob`` keeps only the terms of the complex Wishart log-density that
    depend on ``Sig``::

        (nu - M) log|Sig| - tr(Psi^{-1} Sig)

    It drops ``-nu log|Psi|`` and the log complex multivariate gamma term.
    The result therefore equals the exact log-density only up to an additive
    term that is constant in ``Sig``. That term still depends on ``nu`` and
    ``covariance_matrix``, so do not use this class when those parameters are
    optimized or compared.

    Args:
        nu (float or torch.Tensor): Degrees of freedom of the distribution.
        covariance_matrix (torch.Tensor): Positive-definite covariance matrix
            of shape ``(..., M, M)``.
    """

    def log_prob(self, Sig):
        """Unnormalized log-density evaluated at ``Sig`` (see class docstring)."""
        Psi = self.covariance_matrix
        Psii = torch.linalg.inv(Psi)
        # tr(Psi^{-1} Sig): take the real part and clip from below;
        # presumably a numerical safeguard against round-off.
        trPsiiSig = torch.einsum("...mn,...nm->...", Psii, Sig).real.clip(min=1e-6)
        # Out-of-place subtraction so the trace term may broadcast to a larger
        # batch shape (e.g. batched Psi with an unbatched Sig).
        logp = (self.nu - self.M) * torch.linalg.slogdet(Sig)[1] - trPsiiSig
        return logp
class FastComplexInverseWishart(ComplexInverseWishart):
    """Unnormalized complex inverse Wishart distribution.

    ``log_prob`` keeps only the terms of the complex inverse Wishart
    log-density that depend on ``Sig``::

        -(nu + M) log|Sig| - tr(Psi Sig^{-1})

    It drops ``+nu log|Psi|`` and the log complex multivariate gamma term.
    The result therefore equals the exact log-density only up to an additive
    term that is constant in ``Sig``. That term still depends on ``nu`` and
    ``covariance_matrix``, so do not use this class when those parameters are
    optimized or compared.

    The constructor is inherited from ``ComplexInverseWishart``.

    Args:
        nu (float or torch.Tensor): Degrees of freedom of the distribution.
        covariance_matrix (torch.Tensor): Positive-definite covariance matrix
            of shape ``(..., M, M)``.
    """

    def log_prob(self, Sig):
        """Unnormalized log-density evaluated at ``Sig`` (see class docstring)."""
        Psi = self.covariance_matrix
        Sigi = torch.linalg.inv(Sig)
        # tr(Psi Sig^{-1}): take the real part and clip from below;
        # presumably a numerical safeguard against round-off.
        trPsiSigi = torch.einsum("...mn,...nm->...", Psi, Sigi).real.clip(min=1e-6)
        # Out-of-place subtraction so the trace term may broadcast to a larger
        # batch shape (e.g. batched Psi with an unbatched Sig).
        logp = -(self.nu + self.M) * torch.linalg.slogdet(Sig)[1] - trPsiSigi
        return logp