Source code for sbss.nfca.callbacks.xt_visualizer

# Copyright (C) 2025 National Institute of Advanced Industrial Science and Technology (AIST)
# SPDX-License-Identifier: MIT

from typing import Any

import matplotlib.pyplot as plt
import numpy as np

import torch

import lightning as lt


class XtVisualizerCallback(lt.Callback):
    """Callback to visualize the time-frequency energy (``xt``) of the model output.

    At the start and end of validation, this callback:

    1. picks one random batch element of ``pl_module.dump.xt``;
    2. converts it to decibels (``10 * log10``);
    3. draws one image panel per index of its second axis;
    4. logs the figure through ``pl_module.logger.experiment.add_figure``
       (the TensorBoard ``SummaryWriter`` API).

    The Lightning module is expected to expose a ``dump`` attribute holding
    an ``xt`` tensor whose first axis is the batch axis. After batch
    selection, ``xt`` must be 3-D. Axis 0 is drawn vertically and axis 2
    horizontally (presumably frequency x time — confirm against the model).
    If ``dump`` has a ``data_name`` attribute, it is appended to the logging tag.

    Args:
        dynamic_range_db (float, optional): Color-scale range in dB below the
            peak value of the plotted element. Lower values saturate at the
            bottom of the scale. Defaults to ``80.0``.
    """

    # Class-level default, so subclasses that skip ``super().__init__()`` still work.
    dynamic_range_db: float = 80.0

    def __init__(self, dynamic_range_db: float = 80.0):
        super().__init__()
        self.dynamic_range_db = dynamic_range_db

    def on_validation_start(self, trainer: lt.Trainer, pl_module: Any, tag: str = "training") -> None:
        """Plot ``pl_module.dump.xt`` for one random batch element and log it.

        Does nothing if the module has no (or a ``None``) ``dump``, has no
        logger, or the batch is empty.

        Args:
            trainer (lightning.Trainer): Trainer running the loop. Its
                ``current_epoch`` is used as the logging step.
            pl_module (Any): Lightning module being trained or validated.
            tag (str, optional): Label prefix for the logged figure.
                Defaults to ``"training"``.

        Returns:
            None
        """
        dump = getattr(pl_module, "dump", None)
        logger = getattr(pl_module, "logger", None)
        if dump is None or logger is None:
            # Nothing to plot, or nowhere to log the figure (e.g. Trainer(logger=False)).
            return

        prefix = str(tag)
        if hasattr(dump, "data_name"):
            prefix += f"/{dump.data_name}"

        xt = dump.xt
        batch_size = xt.shape[0]
        if batch_size == 0:
            return

        # Select a random batch element and move only that slice to the CPU.
        # detach() and float() are needed because Tensor.numpy() rejects tensors
        # that require grad and bfloat16 tensors (mixed-precision training).
        b = int(np.random.choice(batch_size))
        logxt = xt[b].detach().float().cpu().log10().mul_(10).numpy()
        _, n_panels, _ = logxt.shape

        # squeeze=False keeps `axs` a 2-D array even when there is a single panel;
        # with the default squeeze=True a lone Axes object is returned, which is not iterable.
        fig, axs = plt.subplots(n_panels, 1, sharex=True, squeeze=False, figsize=[8, 1.5 * n_panels])
        try:
            # Shared color scale across panels, spanning `dynamic_range_db` below the peak.
            vmax = np.nanmax(logxt)
            vmin = vmax - self.dynamic_range_db
            for m, ax in enumerate(axs[:, 0]):
                ax.imshow(logxt[..., m, :], origin="lower", aspect="auto", vmin=vmin, vmax=vmax)

            fig.tight_layout(pad=0.1)

            logger.experiment.add_figure(f"{prefix}/xt", fig, global_step=trainer.current_epoch)
        finally:
            # Always release the figure, even if plotting or logging fails.
            plt.close(fig)

    def on_validation_end(self, trainer: lt.Trainer, pl_module: Any) -> None:
        """Log the same ``xt`` visualization under the ``"validation"`` tag.

        Args:
            trainer (lightning.Trainer): Trainer running the validation loop.
            pl_module (Any): Lightning module being validated.

        Returns:
            None
        """
        self.on_validation_start(trainer, pl_module, tag="validation")